DSA 2025 S2 Sample Assignment¶
IMPORTANT NOTE TO FUTURE STUDENTS: Most of the comments have been removed from this student's submission, to avoid the temptation for future students to copy these comments in their future assignment submissions. It is important that you explain why you are undertaking each step in your notebook and that you interpret the output of each step in your own words.
Purpose:¶
The purpose of this assignment is to assist Betahelf, a large healthcare provider operating across the US, in developing a healthcare program that proactively targets patients who are likely to receive an acute diagnosis over the next 12 months.
In particular, Betahelf is seeking advice in the following three areas:
Are Betahelf's pathologists adding value through their commentaries on lab result reports?
Is it possible to predict which of Betahelf's current patients will receive an acute diagnosis over the next 12 months?
Is it possible to split the causes of acute diagnoses between broad environmental factors (e.g. socio-economic status) and patient-specific factors (e.g. a history of acute disease or poor lifestyle choices)?
Packages¶
This section imports the packages that are used in this notebook.
# Package 1
# For mounting Google Drive
from google.colab import drive
# Package 2
# For data manipulation
import pandas as pd
# Package 3
# For text cleaning
import re
# Package 4
# For name redacting
import spacy
# Packages 5 to 8
# For tokenising and lemmatising
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from sklearn.feature_extraction import text as sk_text
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('stopwords')
# Package 9
# For counting
from collections import Counter
# Package 10
# For generating embeddings
!pip install -U sentence-transformers --quiet
from sentence_transformers import SentenceTransformer
# Packages 11 and 12
# For checking embeddings and TF-IDF vectorisation
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# Package 13
# For TF-IDF vectorisation
from sklearn.feature_extraction.text import TfidfVectorizer
# Package 14
# For plotting of graphs
import matplotlib.pyplot as plt
# Packages 15 and 16
# For PCA
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Packages 17 to 20
# For agglomerative clustering
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import davies_bouldin_score, silhouette_score, calinski_harabasz_score, pairwise_distances, silhouette_samples
import matplotlib.cm as cm
from sklearn.manifold import TSNE
# Packages 21 and 22
# For random forest
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import PartialDependenceDisplay
# Package 23
# For generating word clouds
!pip install wordcloud --quiet
from wordcloud import WordCloud
# Package 24
# For sentiment calculations
!pip install -U transformers --quiet
from transformers import pipeline
# Package 25
# For saving and loading DataFrames
import pickle
# Package 26
# For manipulating dates
from datetime import datetime
# Package 27
# For data splitting
import hashlib
# Packages 28 to 32
# For neural network building and benchmark comparisons
from sklearn.metrics import precision_score, recall_score, roc_auc_score, roc_curve, confusion_matrix
import random
import os
!pip install tensorflow --quiet
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential, load_model
from sklearn.utils import class_weight
!pip install shap --quiet
import shap
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier, plot_tree
# Packages 33 to 36
# For LLM prompt generation
from typing import List, Dict, Optional
!pip install transformers accelerate torch --upgrade --quiet
import torch
from transformers import pipeline
Functions¶
This section sets out user-defined functions that are used in this notebook.
If you prefer, you may instead choose to define these under the relevant assignment question section below.
# Function 1
# Function 2/3/4 etc.
Import data¶
This section imports the data that is used in this notebook.
Reference: DSA 2025 S2 Assignment Data
# Mount Google Drive and import dataset from Google Drive
!pip install openpyxl --quiet
if 'assignmentdata' not in globals():
    drive.mount('/content/gdrive/')
    infolder = '/content/gdrive/My Drive/DSA Assignment Data/'
    filename = 'DSA 2025 S2 Assignment Data.xlsx'
    # Specify dtype for ICD9Code to ensure it is read as a string
    assignmentdata = pd.read_excel(infolder + filename, sheet_name = None, engine = 'openpyxl', dtype = {'ICD9Code': object})
# Check number of imported tables
print(f'Loaded {len(assignmentdata)} tables.')
# Read sheets into separate DataFrames
for name, df in assignmentdata.items():
    globals()[name.lower() + '_df'] = df
Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).
Loaded 11 tables.
Q1 - Explore and examine the pathology comments¶
Q1a - Vectorisation¶
Purpose:¶
The purpose of this section is to use both embeddings and TF-IDF vectorisation to calculate vectorised features after cleaning the 'written_report' column of the 'Pathology' table, with a focus on the commentaries.
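To make the TF-IDF half of this purpose concrete, the following is a minimal hand-rolled sketch on three toy commentaries. The toy texts and the helper names `tfidf_vector` and `cosine` are assumptions made for illustration only; the notebook itself relies on scikit-learn's `TfidfVectorizer`. The smoothed IDF and L2 normalisation used here follow that vectoriser's default settings.

```python
import math
from collections import Counter

# Toy, already-cleaned commentaries (hypothetical, not taken from the dataset)
docs = [
    "elevated triglyceride cardiovascular risk",
    "normal bilirubin normal liver",
    "elevated bilirubin liver disease",
]

tokenised = [doc.split() for doc in docs]
vocab = sorted({tok for toks in tokenised for tok in toks})
n_docs = len(docs)

# Document frequency: number of documents that contain each term
doc_freq = Counter(tok for toks in tokenised for tok in set(toks))

# Smoothed IDF: ln((1 + n) / (1 + df)) + 1
idf = {tok: math.log((1 + n_docs) / (1 + doc_freq[tok])) + 1 for tok in vocab}

def tfidf_vector(tokens):
    # Raw term count times IDF, then scaled to unit (L2) length
    counts = Counter(tokens)
    raw = [counts[tok] * idf[tok] for tok in vocab]
    norm = math.sqrt(sum(v * v for v in raw))
    return [v / norm for v in raw]

def cosine(u, v):
    # Both vectors are unit length, so the dot product equals the cosine
    return sum(a * b for a, b in zip(u, v))

vectors = [tfidf_vector(toks) for toks in tokenised]
sim_liver = cosine(vectors[1], vectors[2])  # share "bilirubin" and "liver"
sim_other = cosine(vectors[0], vectors[1])  # share no terms
print(f"liver pair: {sim_liver:.3f}, unrelated pair: {sim_other:.3f}")
```

Commentaries that share no vocabulary score exactly zero under TF-IDF, which is one reason the notebook also computes dense embeddings: those can place differently worded but related commentaries close together.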
References:¶
Outline any references you have used in this section of your notebook here.
Q1a Step 1 - Text Cleaning¶
# Q1a Step 1
# Set random seed so that results can be replicated
random_seed = 41399
# Initial EDA of "Pathology" table
# Set column width to unlimited to view more text
pd.set_option('display.max_colwidth', None)
# View first five rows to get a glimpse of the table structure/contents
pathology_df.head()
# No blank fields in "written_report" column
(pathology_df['written_report'].isnull() | (pathology_df['written_report'] == 'NaN')).sum()
# View 5 random reports to gain any initial insights
# Some immediate observations:
# 1. The reports appear to start with the findings from the tests conducted,
# followed by a "Comments" section and, sometimes, a "Recommendations" section
# (e.g. row 7197). Other times, this distinction is not explicit and there
# appear to be reports with recommendations in the "Comments" section (e.g. row
# 7444).
# 2. As it is reasonable to think of recommendations as a form of commentary,
# the assumption shall be made, going forward, that recommendations within the
# "Comments" section, as well as those in an explicit "Recommendations" section,
# shall be considered as part of the commentary to be analysed in Q1.
pathology_df['written_report'].sample(n = 5, random_state = random_seed)
# Function to shorten written report
def shorten_written_report(text):
    # Remove everything before the "Comments" section
    text = re.sub('^(.*Comments:)', '', text, flags = re.DOTALL | re.IGNORECASE)
    text = re.sub(r'^(.*Comments\*\*)', '', text, flags = re.DOTALL | re.IGNORECASE)
    return text
# Load Spacy's Named Entity Recognition (NER) model
# Disable other features to decrease runtime
NER = spacy.load("en_core_web_sm", disable = ["tok2vec", "tagger", "parser", "attribute_ruler", "lemmatizer"])
# Function to redact names using Spacy's NER
# Exclude the following words from being redacted
whitelist_redact_names = ['dL', 'Bilirubin']
def redact_names(text):
    doc = NER(text)
    for ent in doc.ents:
        if ent.label_ == "PERSON" and ent.text not in whitelist_redact_names:
            text = text.replace(ent.text, "{PATIENT}")
    return text
# Initialise stopwords
stop_words = stopwords.words('english')
# Extend list of stopwords from scikit package and own list
stop_words.extend(sk_text.ENGLISH_STOP_WORDS.union(['patient', 'doctor', 'clinic', 'clinical', 'pathologist', 'visit', 'appointment', 'report', 'result', 'level', 'range', 'value', 'test', 'testing', 'finding', 'found', 'indicate', 'indicating', 'indicated', 'suggest', 'suggesting', 'blood', 'overall', 'consult', 'consults', 'consultation', 'function', 'let', 'healthcare', 'provider', 'include', 'including', 'included', 'regarding', 'physician', 'recommend', 'recommends', 'recommendation', 'recommended', 'consider', 'applicable', 'health', 'evaluate', 'evaluates', 'evaluation', 'management', 'assess', 'assessment', 'assessed', 'necessary', 'ensure', 'outcome', 'know', 'required', 'directly', 'related', 'given', 'need', 'available', 'lifestyle', 'factor', 'referring', 'needed', 'determine', 'cause', 'monitor', 'monitoring', 'mg', 'dl', 'status', 'count', 'condition', 'important', 'ratio', 'modification', 'reference', 'follow', 'issue', 'impact', 'total', 'additional', 'investigation', 'based', 'previous', 'warrant', 'feel', 'free', 'mr', 'date', 'birth']))
# Remove duplicated stopwords for more efficient checking
stop_words = set(stop_words)
# Print the stopwords to check for any that should be excluded from the list
for words in sorted(stop_words):
    print(words)
# Exclude the following stopwords
whitelist_stop_words = ['above', 'below', 'serious', 'no']
# Final list of stopwords
stop_words = [word for word in stop_words if word not in whitelist_stop_words]
# Function to clean text
def pre_process_written_report(text):
    # 1. Lowercase text
    text = text.lower()
    # 2. Replace "\n" with whitespace
    text = text.replace('\n', ' ')
    # 3. Remove signatures (e.g. of the pathologist as in row 7444)
    text = re.sub(r'\[Signature\].*?$', '', text, flags = re.DOTALL | re.IGNORECASE)
    # 4. Remove redacted names
    text = re.sub(r'\{PATIENT\}', '', text, flags = re.DOTALL | re.IGNORECASE)
    # 5. Expand abbreviations and contractions
    text = re.sub(r'n\'t', ' not', text)
    text = re.sub(r'\'re', ' are', text)
    text = re.sub(r'\'s', ' is', text)
    text = re.sub(r'\'d', ' would', text)
    text = re.sub(r'\'ll', ' will', text)
    text = re.sub(r'\'ve', ' have', text)
    text = re.sub(r'\'m', ' am', text)
    # 6. (Not used) Remove all non-alphanumeric characters
    # text = re.sub(r'[^\w\s]', ' ', text, flags = re.DOTALL | re.IGNORECASE)
    # 6. Remove all non-alpha characters
    text = re.sub(r'[^a-zA-Z\s]', ' ', text, flags = re.DOTALL | re.IGNORECASE)
    # 7. Remove leading and trailing whitespaces
    text = text.strip()
    # 8. Tokenise the text
    tokens = nltk.word_tokenize(text)
    # 9. Remove stopwords and lemmatise remaining words
    lemmatised_tokens = [WordNetLemmatizer().lemmatize(token) for token in tokens if token not in stop_words]
    # 10. Remove stopwords again after lemmatising
    lemmatised_tokens_without_stopwords = [token for token in lemmatised_tokens if token not in stop_words]
    return " ".join(lemmatised_tokens_without_stopwords)
# Only apply redaction for embeddings (after shortening)
pathology_df['redacted_commentary'] = pathology_df['written_report'].apply(shorten_written_report).apply(redact_names)
# Apply both redaction and pre-processing for TF-IDF (after shortening)
pathology_df['tfidf_commentary'] = pathology_df['written_report'].apply(shorten_written_report).apply(redact_names).apply(pre_process_written_report)
# Function to find rows not in UTF-8 encoding
def is_utf8(text):
    try:
        text.encode('utf-8').decode('utf-8', 'strict')
        return True
    except UnicodeError:
        return False
changed_rows = []
for index, row in pathology_df.iterrows():
    description = row['redacted_commentary']
    if isinstance(description, str) and not is_utf8(description):
        try:
            encoded_description = description.encode('utf-8', 'replace').decode('utf-8')
            pathology_df.loc[index, 'redacted_commentary'] = encoded_description
            # Track indexes of changed rows
            changed_rows.append(index)
        except Exception as e:
            print(f"Error encoding row {index}: {e}")
# Show rows that were changed due to the encoding
if changed_rows:
    print(f"\n{len(changed_rows)} rows were changed due to encoding")
    display(pathology_df.loc[changed_rows])
else:
    print("\nNo rows were changed due to encoding.")
# Sample check 5 random report commentaries to verify above data cleaning steps
display(pathology_df[['written_report', 'redacted_commentary', 'tfidf_commentary']].sample(n = 5, random_state = random_seed))
# Check the top 20 word frequencies in 'tfidf_commentary' for any more stopwords to add
pathology_df['tfidf_commentary'].str.split(expand = True).stack().value_counts().head(20)
a about above across additional after afterwards again against ain all almost alone along already also although always am among amongst amoungst amount an and another any anyhow anyone anything anyway anywhere applicable appointment are aren aren't around as assess assessed assessment at available back based be became because become becomes becoming been before beforehand behind being below beside besides between beyond bill birth blood both bottom but by call can cannot cant cause clinic clinical co con condition consider consult consultation consults could couldn couldn't couldnt count cry d date de describe detail determine did didn didn't directly dl do doctor does doesn doesn't doing don don't done down due during each eg eight either eleven else elsewhere empty enough ensure etc evaluate evaluates evaluation even ever every everyone everything everywhere except factor feel few fifteen fifty fill find finding fire first five follow for former formerly forty found four free from front full function further get give given go had hadn hadn't has hasn hasn't hasnt have haven haven't having he he'd he'll he's health healthcare hence her here hereafter hereby herein hereupon hers herself him himself his how however hundred i i'd i'll i'm i've ie if impact important in inc include included including indeed indicate indicated indicating interest into investigation is isn isn't issue it it'd it'll it's its itself just keep know last latter latterly least less let level lifestyle ll ltd m ma made management many may me meanwhile mg might mightn mightn't mill mine modification monitor monitoring more moreover most mostly move mr much must mustn mustn't my myself name namely necessary need needed needn needn't neither never nevertheless next nine no nobody none noone nor not nothing now nowhere o of off often on once one only onto or other others otherwise our ours ourselves out outcome over overall own part pathologist patient per perhaps physician please previous 
provider put range rather ratio re recommend recommendation recommended recommends reference referring regarding related report required result s same see seem seemed seeming seems serious several shan shan't she she'd she'll she's should should've shouldn shouldn't show side since sincere six sixty so some somehow someone something sometime sometimes somewhere status still such suggest suggesting system t take ten test testing than that that'll the their theirs them themselves then thence there thereafter thereby therefore therein thereupon these they they'd they'll they're they've thick thin third this those though three through throughout thru thus to together too top total toward towards twelve twenty two un under until up upon us value ve very via visit warrant was wasn wasn't we we'd we'll we're we've well were weren weren't what whatever when whence whenever where whereafter whereas whereby wherein whereupon wherever whether which while whither who whoever whole whom whose why will with within without won won't would wouldn wouldn't y yet you you'd you'll you're you've your yours yourself yourselves No rows were changed due to encoding.
| | written_report | redacted_commentary | tfidf_commentary |
|---|---|---|---|
| 7444 | **Pathology Report:**\nThe blood test results for Rikki Caldwell show the following abnormal results:\n1. **Bilirubin:**\n - Observation Value: 0.3 mg/dL\n - Reference Range: 0.2-1.2 mg/dL\n - Abnormal Flags: NA\n - Is Abnormal Value: FALSE\n2. **Total Protein:**\n - Observation Value: 7.6 g/dL\n - Reference Range: 6.4-8.3 g/dL\n - Abnormal Flags: NA\n - Is Abnormal Value: FALSE\n**Comments:**\nThe bilirubin level of 0.3 mg/dL falls within the normal reference range of 0.2-1.2 mg/dL, indicating normal liver function. Similarly, the total protein level of 7.6 g/dL is within the reference range of 6.4-8.3 g/dL, suggesting no abnormalities in this parameter.\nGiven that the patient is a previous smoker with a history of smoking, it is important to monitor liver function regularly as smoking can impact liver health. However, based on the current results, there are no immediate concerns regarding liver function or overall protein levels.\nI recommend continued monitoring of liver function and regular follow-up appointments to assess any changes over time, especially considering the patient's smoking history.\nPlease let me know if further tests or consultations are required.\n[Signature] \nClinical Pathologist | **\nThe bilirubin level of 0.3 mg/dL falls within the normal reference range of 0.2-1.2 mg/dL, indicating normal liver function. Similarly, the total protein level of 7.6 g/dL is within the reference range of 6.4-8.3 g/dL, suggesting no abnormalities in this parameter.\nGiven that the patient is a previous smoker with a history of smoking, it is important to monitor liver function regularly as smoking can impact liver health. 
However, based on the current results, there are no immediate concerns regarding liver function or overall protein levels.\nI recommend continued monitoring of liver function and regular follow-up appointments to assess any changes over time, especially considering the patient's smoking history.\nPlease let me know if further tests or consultations are required.\n[Signature] \nClinical Pathologist | bilirubin fall normal normal liver similarly protein g g no abnormality parameter smoker history smoking liver regularly smoking liver current no immediate concern liver protein continued liver regular change time especially considering smoking history |
| 7197 | Pathology Report:\nPatient: Jennifer Page\nDate of Birth: August 5, 2003\nSmoking Status: Non-smoker\nAbnormal Results:\n1. Chloride, Serum: 98.0 mmol/L (Reference Range: 3.5-5.2 mmol/L)\n2. Globulin: 2.6 g/dL (Reference Range: 1.1-2.5 g/dL)\n3. Nitrite: 102.0 x10E3/uL (Reference Range: 0.1-1.0 x10E3/uL)\n4. Hemoglobin: 5.1 g/dL (Reference Range: 12.0-16.0 g/dL)\n5. Triglyceride: 342.0 mg/dL (Reference Range: 1.8-7.8 mg/dL)\nComments:\nThe blood test results for Jennifer Page show several abnormal findings that warrant further investigation. The elevated Chloride level may indicate dehydration or kidney issues. The high Globulin level suggests inflammation or infection in the body. The Nitrite levels are significantly elevated, which may indicate a urinary tract infection. The low Hemoglobin level could signify anemia, which needs to be addressed promptly. The high Triglyceride level is a risk factor for cardiovascular disease.\nAs Jennifer is a non-smoker, her abnormal results are not directly related to smoking. However, lifestyle factors such as diet and exercise may contribute to some of these abnormal findings. Further medical evaluation and possible treatments are recommended based on these results. Please consult with Jennifer's healthcare provider to address these abnormal findings promptly.\nRecommendations:\n1. Follow up with a healthcare provider for further evaluation and management of the abnormal results.\n2. Consider lifestyle modifications such as a healthy diet and regular exercise to improve overall health.\n3. Monitor and retest any abnormal findings to track progress and effectiveness of treatments. | \nThe blood test results for {PATIENT} show several abnormal findings that warrant further investigation. The elevated Chloride level may indicate dehydration or kidney issues. The high Globulin level suggests inflammation or infection in the body. The Nitrite levels are significantly elevated, which may indicate a urinary tract infection. 
The low {PATIENT} level could signify anemia, which needs to be addressed promptly. The high Triglyceride level is a risk factor for cardiovascular disease.\nAs {PATIENT} is a non-smoker, her abnormal results are not directly related to smoking. However, lifestyle factors such as diet and exercise may contribute to some of these abnormal findings. Further medical evaluation and possible treatments are recommended based on these results. Please consult with {PATIENT}'s healthcare provider to address these abnormal findings promptly.\nRecommendations:\n1. Follow up with a healthcare provider for further evaluation and management of the abnormal results.\n2. Consider lifestyle modifications such as a healthy diet and regular exercise to improve overall health.\n3. Monitor and retest any abnormal findings to track progress and effectiveness of treatments. | abnormal elevated chloride dehydration kidney high globulin suggests inflammation infection body nitrite significantly elevated urinary tract infection low signify anemia addressed promptly high triglyceride risk cardiovascular disease non smoker abnormal smoking diet exercise contribute abnormal medical possible treatment address abnormal promptly abnormal healthy diet regular exercise improve retest abnormal track progress effectiveness treatment |
| 6744 | **Pathology Report**\n**Patient Information:**\n- Name: Sara Hartzler\n- Date of Birth: November 18, 1982\n- Smoking Status: Not Available\n**Laboratory Results:**\n1. Albumin / Globulin Ratio: Not Available\n2. Bilirubin: 0.6 mg/dL (Reference Range: 97-108 mg/dL)\n**Comments:**\nThe results indicate a normal bilirubin level of 0.6 mg/dL, within the reference range of 97-108 mg/dL. However, the Albumin / Globulin Ratio result is not available for interpretation.\nAs the patient's smoking status is not provided, and the available results do not directly correlate with smoking-related conditions, further evaluation would be needed to assess any potential impact on the test results. It is recommended to follow up with the patient for additional information and repeat the Albumin / Globulin Ratio test for a comprehensive assessment.\nFurther investigations or consultations may be necessary based on the clinical presentation and the current laboratory results to provide a comprehensive evaluation and appropriate management plan for the patient. | **\nThe results indicate a normal bilirubin level of 0.6 mg/dL, within the reference range of 97-108 mg/dL. However, the Albumin / Globulin Ratio result is not available for interpretation.\nAs the patient's smoking status is not provided, and the available results do not directly correlate with smoking-related conditions, further evaluation would be needed to assess any potential impact on the test results. It is recommended to follow up with the patient for additional information and repeat the Albumin / Globulin Ratio test for a comprehensive assessment.\nFurther investigations or consultations may be necessary based on the clinical presentation and the current laboratory results to provide a comprehensive evaluation and appropriate management plan for the patient. 
| normal bilirubin albumin globulin interpretation smoking provided correlate smoking potential information repeat albumin globulin comprehensive presentation current laboratory provide comprehensive appropriate plan |
| 1770 | **Pathology Report:**\n**Patient Information:**\n- Name: Karen Danielson\n- Date of Birth: May 20, 1989\n- Smoking Status: NA\n**Abnormal Results:**\n1. **Cholesterol / HDL Ratio:** \n - Observation Value: 130.0 K/uL\n - Reference Range: 6.4-8.5\n - Abnormal: No\n - Implication: The cholesterol/HDL ratio is within the normal range.\n \n2. **Urinalysis Reflex:**\n - Observation Value: 0.6 thou/cmm\n - Abnormal: No\n - Implication: The urinalysis reflex result is within normal limits.\n \n3. **Chloride, Serum:**\n - Observation Value: 107.0 mmol/L\n - Reference Range: 3.5-5.2\n - Abnormal: No\n - Implication: The chloride level in the serum is within the normal range.\n \n4. **Follicle Stimulating Hormone:**\n - Observation Value: 292.0 x10E3/uL\n - Abnormal: No\n - Implication: The Follicle Stimulating Hormone level is within normal limits.\n5. **Bilirubin:**\n - Observation Value: 0.4 K/uL, 0.3 mg/dL\n - Reference Range: None seen/Few, 97-108\n - Abnormal: No\n - Implication: Both bilirubin levels are within normal limits.\n6. **Cholesterol / HDL Ratio:** \n - Observation Value: 0.4 K/uL\n - Reference Range: 6.4-8.5\n - Abnormal: No\n - Implication: The cholesterol/HDL ratio is within the normal range.\n**Comments:**\nOverall, most of the blood test results for Karen Danielson are within normal limits, indicating a generally healthy status. Given the patient's date of birth and current non-smoking status, it is important for her to continue with a healthy lifestyle, including regular exercise and a balanced diet. Monitoring these blood parameters regularly will also be beneficial for maintaining her health status.\nIt is important to note that as a clinical pathologist, I do not have information regarding the patient's full medical history. If there are any concerns or additional tests required, consultation with the referring physician is recommended. 
| **\nOverall, most of the blood test results for {PATIENT} are within normal limits, indicating a generally healthy status. Given the patient's date of birth and current non-smoking status, it is important for her to continue with a healthy lifestyle, including regular exercise and a balanced diet. Monitoring these blood parameters regularly will also be beneficial for maintaining her health status.\nIt is important to note that as a clinical pathologist, I do not have information regarding the patient's full medical history. If there are any concerns or additional tests required, consultation with the referring physician is recommended. | normal limit generally healthy current non smoking continue healthy regular exercise balanced diet parameter regularly beneficial maintaining note information medical history concern |
| 11488 | ## Pathology Report:\n### Patient Information:\n- **Name:** Kathleen Lehmann\n- **Date of Birth:** March 19, 1945\n- **Smoking Status:** NA\n### Test Results:\n1. **Albumin / Globulin Ratio:** 1.5 (Reference Range: 6.0-8.5)\n2. **Bilirubin:** 0.5 mg/dL (Reference Range: 0.3-1.0)\n3. **Chloride, Serum:** 101.0 mmol/L (Reference Range: 96-106)\n4. **Globulin:** 2.8 g/dL (Reference Range: 2.3-3.5)\n5. **Hematocrit:** 40.6% (Reference Range: 38.3-48.6)\n6. **Hemoglobin:** 13.8 g/dL (Reference Range: 12-15)\n7. **Monocytes:** 0.7 x10E3/uL (Reference Range: 0.1-0.8)\n8. **Neutrophils:** 4.2 x10E3/uL (Reference Range: 2.0-7.0)\n9. **Platelet Count:** 275.0 x10E3/uL (Reference Range: 150-400) - **Above Normal High**\n10. **Potassium, Serum:** 3.4 mmol/L (Reference Range: 3.5-5.0)\n11. **Total Protein:** 7.0 g/dL (Reference Range: 6.3-7.9)\n12. **Triglycerides:** 154.0 mg/dL (Reference Range: 70-150)\n### Comments:\n- The Platelet Count is **above normal high**, which could be an indication of various conditions, including inflammatory disorders, blood disorders, or even potentially cancer. Further investigation may be warranted.\n- The Potassium level is slightly below the reference range, which could be a cause for concern and should be monitored.\n- No direct correlation between the test results and smoking status was identified in this case.\n- Overall, the majority of the test results are within normal limits, with only slight variations in some parameters.\n### Recommendations:\n1. Further evaluation of the platelet count deviation from normal range may be necessary for possible underlying conditions.\n2. Monitor potassium levels closely and consider dietary adjustments if needed.\n3. Follow up with the patient for routine check-ups to ensure overall health and well-being. | \n- The Platelet Count is **above normal high**, which could be an indication of various conditions, including inflammatory disorders, blood disorders, or even potentially cancer. 
Further investigation may be warranted.\n- The Potassium level is slightly below the reference range, which could be a cause for concern and should be monitored.\n- No direct correlation between the test results and smoking status was identified in this case.\n- Overall, the majority of the test results are within normal limits, with only slight variations in some parameters.\n### Recommendations:\n1. Further evaluation of the platelet count deviation from normal range may be necessary for possible underlying conditions.\n2. Monitor potassium levels closely and consider dietary adjustments if needed.\n3. Follow up with the patient for routine check-ups to ensure overall health and well-being. | platelet above normal high indication various inflammatory disorder disorder potentially cancer warranted potassium slightly below concern monitored no direct correlation smoking identified case majority normal limit slight variation parameter platelet deviation normal possible underlying potassium closely dietary adjustment routine check ups |
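The sample rows above show that the shortened commentary can keep stray markdown characters (for example a leading `**` when the heading was written as `**Comments:**`). The sketch below reproduces that behaviour on a toy report and shows one possible follow-up step. The toy report and the helper `strip_leading_markup` are hypothetical and are not part of the submitted notebook.

```python
import re

def shorten_written_report(text):
    # Same two patterns as in the notebook: the greedy ".*" with DOTALL removes
    # everything up to and including the last "Comments:" or "Comments**"
    text = re.sub(r'^(.*Comments:)', '', text, flags=re.DOTALL | re.IGNORECASE)
    text = re.sub(r'^(.*Comments\*\*)', '', text, flags=re.DOTALL | re.IGNORECASE)
    return text

def strip_leading_markup(text):
    # Hypothetical extra step: drop leftover "*", "#" and whitespace at the start
    return re.sub(r'^[\s*#]+', '', text)

# Toy report (hypothetical), formatted like the "**Comments:**" style seen above
report = (
    "**Pathology Report:**\n"
    "1. **Bilirubin:** 0.3 mg/dL\n"
    "**Comments:**\n"
    "The bilirubin level is normal."
)

shortened = shorten_written_report(report)
cleaned = strip_leading_markup(shortened)
print(repr(shortened))
print(repr(cleaned))
```

For the TF-IDF text this residue is harmless, because the later pre-processing removes all non-alphabetic characters. It only matters for the text passed to the embedding model, where the leading characters are kept.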
| | count |
|---|---|
| smoking | 28246 |
| elevated | 13815 |
| abnormal | 12868 |
| normal | 12642 |
| liver | 11256 |
| underlying | 11108 |
| potential | 8529 |
| triglyceride | 7110 |
| abnormality | 7045 |
| smoker | 6720 |
| globulin | 6674 |
| risk | 6482 |
| regular | 6179 |
| history | 5654 |
| bilirubin | 5539 |
| protein | 5318 |
| low | 5304 |
| cessation | 4943 |
| disease | 4741 |
| cardiovascular | 4682 |
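The frequency table above can also be produced with `collections.Counter`, which the notebook already imports. The following is a minimal sketch under the assumption that each cleaned commentary is a whitespace-separated string; the list `cleaned_commentaries` is a hypothetical stand-in for `pathology_df['tfidf_commentary']`.

```python
from collections import Counter

# Hypothetical cleaned commentaries standing in for pathology_df['tfidf_commentary']
cleaned_commentaries = [
    "smoking elevated liver",
    "smoking normal liver",
    "elevated triglyceride smoking",
]

# Accumulate token counts one commentary at a time
token_counts = Counter()
for commentary in cleaned_commentaries:
    token_counts.update(commentary.split())

# Most frequent tokens first; candidates for further stopword review
top_terms = token_counts.most_common(3)
print(top_terms)
```

Counting this way avoids building the wide intermediate DataFrame that `str.split(expand = True)` creates, which may matter when the commentaries are long.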
Q1a Step 2 - Embeddings and TF-IDF Vectorisation¶
Embeddings¶
# Q1a Step 2
# Generate embeddings for the 'redacted_commentary' column
model = SentenceTransformer('xlreator/biosyn-biobert-snomed')
embeddings = model.encode(pathology_df['redacted_commentary'].tolist())
# Check that the number of rows of the embeddings is the same as the number of rows in pathology_df
if len(embeddings) == len(pathology_df):
    print("Number of rows in embeddings is the same as in pathology_df.")
else:
    print("Number of rows in embeddings is not the same as in pathology_df.")
# Concatenate the embeddings with the original data
pathology_df = pd.concat([pathology_df, pd.DataFrame(embeddings)], axis = 1)
# Append "embedding" as prefix to columns with integer names for clarity
for col in pathology_df.columns:
    if isinstance(col, int):
        pathology_df.rename(columns = {col: f"embedding_{col}"}, inplace = True)
# Check for NaN values in embedding columns
print("Total NaN values in embeddings:", pathology_df[[col for col in pathology_df.columns if col.startswith("embedding_")]].isna().sum().sum())
# Check for 0 values in embedding columns
print("Total zero values in embeddings:", (pathology_df[[col for col in pathology_df.columns if col.startswith("embedding_")]] == 0).sum().sum())
# Show default columns
pd.set_option('display.max_columns', 0)
# Calculate cosine similarity between all embedding vectors
embedding_similarity_matrix = cosine_similarity(pathology_df[[col for col in pathology_df.columns if col.startswith("embedding_")]].values)
# Find the two rows with the highest similarity excluding self-similarity
np.fill_diagonal(embedding_similarity_matrix, -1)
most_similar_indices = np.unravel_index(np.argmax(embedding_similarity_matrix), embedding_similarity_matrix.shape)
# Display the two most similar rows
print("Two most similar rows based on embeddings")
display(pathology_df.iloc[[most_similar_indices[0], most_similar_indices[1]]])
# Find the two rows with the lowest similarity excluding self-similarity
np.fill_diagonal(embedding_similarity_matrix, 1)
most_distant_indices = np.unravel_index(np.argmin(embedding_similarity_matrix), embedding_similarity_matrix.shape)
# Display the two most distant rows
print("Two most distant rows based on embeddings:")
display(pathology_df.iloc[[most_distant_indices[0], most_distant_indices[1]]])
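# Histogram of cosine similarity between all distinct pairs of embedding vectors
# Use only the upper triangle (k = 1) so that the diagonal (self-similarity) and
# the mirrored duplicate of each pair are excluded
plt.figure(figsize = (8, 5))
plt.hist(embedding_similarity_matrix[np.triu(np.ones(embedding_similarity_matrix.shape, dtype = bool), k = 1)], bins = 20, edgecolor = 'black')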
plt.xlabel("Cosine Similarity")
plt.ylabel("Frequency")
plt.title("Distribution of Cosine Similarity based on Embeddings (excluding self-similarity)")
plt.grid()
plt.show()
Number of rows in embeddings is the same as in pathology_df.
Total NaN values in embeddings: 0
Total zero values in embeddings: 0
Two most similar rows based on embeddings
| LabResultGuid | written_report | redacted_commentary | tfidf_commentary | embedding_0 | embedding_1 | embedding_2 | embedding_3 | embedding_4 | embedding_5 | embedding_6 | embedding_7 | embedding_8 | embedding_9 | embedding_10 | embedding_11 | embedding_12 | embedding_13 | embedding_14 | embedding_15 | embedding_16 | embedding_17 | embedding_18 | embedding_19 | embedding_20 | embedding_21 | embedding_22 | embedding_23 | embedding_24 | embedding_25 | embedding_26 | embedding_27 | embedding_28 | embedding_29 | embedding_30 | embedding_31 | embedding_32 | embedding_33 | embedding_34 | embedding_35 | ... | embedding_728 | embedding_729 | embedding_730 | embedding_731 | embedding_732 | embedding_733 | embedding_734 | embedding_735 | embedding_736 | embedding_737 | embedding_738 | embedding_739 | embedding_740 | embedding_741 | embedding_742 | embedding_743 | embedding_744 | embedding_745 | embedding_746 | embedding_747 | embedding_748 | embedding_749 | embedding_750 | embedding_751 | embedding_752 | embedding_753 | embedding_754 | embedding_755 | embedding_756 | embedding_757 | embedding_758 | embedding_759 | embedding_760 | embedding_761 | embedding_762 | embedding_763 | embedding_764 | embedding_765 | embedding_766 | embedding_767 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 9284 | b8e629c9-5b07-44f5-b120-364c5eb09687 | **Pathology Report for Lauren Jones**\n**Patient Information:**\n- Name: Lauren Jones\n- Date of Birth: March 28, 1994\n- Smoking Status: Non-smoker\n**Abnormal Results:**\n1. **Ketones:** The ketones level is elevated at 35.8%, which may indicate a metabolic imbalance or insufficient carbohydrate intake. Further evaluation may be needed.\n2. **Protein Total:** The total protein level is elevated at 7.1 g/dL, above the reference range of 0.0-1.2 g/dL. This could be a sign of dehydration or certain medical conditions that affect protein levels in the blood.\n3. **Triglyceride:** The triglyceride level is elevated at 8.1 mg/dL, above the reference range of 1.8-7.8 mg/dL. This could indicate an increased risk of cardiovascular disease and may require lifestyle modifications.\n4. **Triglyceride (again):** Another triglyceride result shows a significantly elevated level of 12.6 mmol/L, which is concerning and requires further investigation and management.\n**Comments:**\nBased on the abnormal results in the blood test, further evaluation and follow-up may be necessary to determine the underlying cause of these abnormalities. It is recommended to consult with a healthcare provider for appropriate management and guidance. Additionally, lifestyle modifications such as dietary changes and regular exercise may be beneficial, especially in the case of elevated triglyceride levels. As Lauren Jones is a non-smoker, it is important to address these abnormal results promptly to promote overall health and well-being. | **\nBased on the abnormal results in the blood test, further evaluation and follow-up may be necessary to determine the underlying cause of these abnormalities. It is recommended to consult with a healthcare provider for appropriate management and guidance. Additionally, lifestyle modifications such as dietary changes and regular exercise may be beneficial, especially in the case of elevated triglyceride levels. 
As {PATIENT} is a non-smoker, it is important to address these abnormal results promptly to promote overall health and well-being. | abnormal underlying abnormality appropriate guidance additionally dietary change regular exercise beneficial especially case elevated triglyceride non smoker address abnormal promptly promote | -0.22112 | 0.110824 | 0.462621 | -0.677728 | 0.141288 | 0.439689 | 0.339617 | -0.338564 | -0.758425 | -0.025648 | 0.745713 | 0.702073 | -0.178007 | -0.110599 | -0.673737 | -0.39933 | 0.085955 | 0.475142 | 0.669558 | 0.313818 | -0.081713 | -0.336667 | -0.708475 | 0.181716 | -0.260547 | -0.191626 | -0.549756 | -0.060871 | 0.259309 | 0.156446 | -0.810225 | -0.184567 | -0.122687 | 0.271737 | 0.297283 | -0.480869 | ... | 0.33785 | 0.495703 | 0.295427 | 0.299392 | -0.325809 | 0.429643 | -0.938922 | 0.259923 | 0.247953 | -0.076533 | -0.307496 | 0.431351 | 0.367086 | -0.295085 | 0.004928 | 0.033894 | -0.015974 | 0.14501 | -0.14337 | 0.451492 | 0.240021 | 0.03326 | -0.028358 | 0.144388 | -0.29338 | -0.062095 | 0.059938 | -0.912561 | -0.328414 | 0.472607 | -0.16557 | -0.734162 | -0.062211 | 0.383606 | 0.118487 | 0.274979 | -0.619662 | -0.537403 | 0.062667 | 0.143218 |
| 11186 | 44f166d1-879d-4961-ba8e-4342793dfcf6 | ## **Pathology Report:**\n### **Patient Information:**\n- **Patient Name:** Terry Stuart\n- **Date of Birth:** March 22, 2000\n- **Smoking Status:** Non-smoker or less than 100 cigarettes in a lifetime\n### **Abnormal Results:**\n1. **Albumin / Globulin Ratio:**\n - **Observation Value:** NA\n - **Reference Range:** NA\n - **Abnormal Flags:** Above Normal High\n - **Is Abnormal Value:** TRUE\n - **Implications:** The albumin/globulin ratio is higher than normal, indicating a potential liver or kidney issue.\n2. **Protein Total:**\n - **Observation Value:** 6.1 g/dL\n - **Reference Range:** 0.0-1.2 g/dL\n - **Abnormal Flags:** NA\n - **Is Abnormal Value:** FALSE\n - **Implications:** The total protein level is significantly elevated, which could be indicative of multiple myeloma or chronic inflammation.\n3. **Triglyceride:**\n - **Observation Value:** 178.0 mg/dL\n - **Reference Range:** 1.8-7.8 mg/dL\n - **Abnormal Flags:** NA\n - **Is Abnormal Value:** FALSE\n - **Implications:** The triglyceride level is elevated, suggesting a risk for cardiovascular disease or metabolic disorders.\n### **Comments:**\nBased on the abnormal results in the blood test, further evaluation and follow-up may be necessary to determine the underlying cause of these abnormalities. Given the patient's non-smoking status and the abnormal albumin/globulin ratio and total protein levels, additional tests and consultations with a specialist may be required to properly diagnose and address any potential health issues. 
Lifestyle modifications and dietary changes may also be beneficial in improving the abnormal results observed.\nIt is crucial for the patient to schedule a follow-up appointment with their healthcare provider to discuss these findings and develop a suitable management plan.\n---\nAs a Clinical Pathologist, it is important to provide a comprehensive analysis of the blood test results to guide further assessment and treatment for the patient. | **\nBased on the abnormal results in the blood test, further evaluation and follow-up may be necessary to determine the underlying cause of these abnormalities. Given the patient's non-smoking status and the abnormal albumin/globulin ratio and total protein levels, additional tests and consultations with a specialist may be required to properly diagnose and address any potential health issues. Lifestyle modifications and dietary changes may also be beneficial in improving the abnormal results observed.\nIt is crucial for the patient to schedule a follow-up appointment with their healthcare provider to discuss these findings and develop a suitable management plan.\n---\nAs a Clinical Pathologist, it is important to provide a comprehensive analysis of the blood test results to guide further assessment and treatment for the patient. | abnormal underlying abnormality non smoking abnormal albumin globulin protein specialist properly diagnose address potential dietary change beneficial improving abnormal observed crucial schedule discus develop suitable plan provide comprehensive analysis guide treatment | -0.22112 | 0.110824 | 0.462621 | -0.677728 | 0.141288 | 0.439689 | 0.339617 | -0.338564 | -0.758425 | -0.025648 | 0.745713 | 0.702073 | -0.178007 | -0.110599 | -0.673737 | -0.39933 | 0.085955 | 0.475142 | 0.669558 | 0.313818 | -0.081713 | -0.336667 | -0.708475 | 0.181716 | -0.260547 | -0.191626 | -0.549756 | -0.060871 | 0.259309 | 0.156446 | -0.810225 | -0.184567 | -0.122687 | 0.271737 | 0.297283 | -0.480869 | ... 
| 0.33785 | 0.495703 | 0.295427 | 0.299392 | -0.325809 | 0.429643 | -0.938922 | 0.259923 | 0.247953 | -0.076533 | -0.307496 | 0.431351 | 0.367086 | -0.295085 | 0.004928 | 0.033894 | -0.015974 | 0.14501 | -0.14337 | 0.451492 | 0.240021 | 0.03326 | -0.028358 | 0.144388 | -0.29338 | -0.062095 | 0.059938 | -0.912561 | -0.328414 | 0.472607 | -0.16557 | -0.734162 | -0.062211 | 0.383606 | 0.118487 | 0.274979 | -0.619662 | -0.537403 | 0.062667 | 0.143218 |
2 rows × 772 columns
Two most distant rows based on embeddings:
| LabResultGuid | written_report | redacted_commentary | tfidf_commentary | embedding_0 | embedding_1 | embedding_2 | embedding_3 | embedding_4 | embedding_5 | embedding_6 | embedding_7 | embedding_8 | embedding_9 | embedding_10 | embedding_11 | embedding_12 | embedding_13 | embedding_14 | embedding_15 | embedding_16 | embedding_17 | embedding_18 | embedding_19 | embedding_20 | embedding_21 | embedding_22 | embedding_23 | embedding_24 | embedding_25 | embedding_26 | embedding_27 | embedding_28 | embedding_29 | embedding_30 | embedding_31 | embedding_32 | embedding_33 | embedding_34 | embedding_35 | ... | embedding_728 | embedding_729 | embedding_730 | embedding_731 | embedding_732 | embedding_733 | embedding_734 | embedding_735 | embedding_736 | embedding_737 | embedding_738 | embedding_739 | embedding_740 | embedding_741 | embedding_742 | embedding_743 | embedding_744 | embedding_745 | embedding_746 | embedding_747 | embedding_748 | embedding_749 | embedding_750 | embedding_751 | embedding_752 | embedding_753 | embedding_754 | embedding_755 | embedding_756 | embedding_757 | embedding_758 | embedding_759 | embedding_760 | embedding_761 | embedding_762 | embedding_763 | embedding_764 | embedding_765 | embedding_766 | embedding_767 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1991 | af5459d5-9cbd-45b1-896b-86b29493514c | **Pathology Report:**\nThe blood test results for Stacey Smith indicate normal levels of Thyroxine and Cannabinoids. There are no abnormal flags for either of these lab results.\n**Comments:**\nGiven Stacey's date of birth in 1948 and her smoking status of 0 cigarettes per day (non-smoker or less than 100 in lifetime), it is important to note that smoking can have significant effects on various blood test results. However, in this case, Stacey's thyroid and cannabinoid levels appear to be within normal ranges, which is a positive finding.\nI recommend that Stacey continues to maintain a healthy lifestyle and regular visits with her healthcare provider to monitor her overall health.\nIf there are any specific concerns or symptoms that arise, further testing or consultation with a healthcare provider may be warranted. | **\nGiven {PATIENT}'s date of birth in 1948 and her smoking status of 0 cigarettes per day (non-smoker or less than 100 in lifetime), it is important to note that smoking can have significant effects on various blood test results. However, in this case, {PATIENT}'s thyroid and cannabinoid levels appear to be within normal ranges, which is a positive finding.\nI recommend that {PATIENT} continues to maintain a healthy lifestyle and regular visits with her healthcare provider to monitor her overall health.\nIf there are any specific concerns or symptoms that arise, further testing or consultation with a healthcare provider may be warranted. 
| smoking cigarette day non smoker lifetime note smoking significant effect various case thyroid cannabinoid appear normal positive continues maintain healthy regular specific concern symptom arise warranted | 0.050315 | -0.338145 | 0.123371 | -0.778844 | 0.289930 | -0.023063 | 0.06353 | -0.079109 | -0.666223 | 0.577689 | -0.326359 | 0.767733 | -0.527561 | -0.307863 | 1.204383 | 1.102557 | 0.727143 | 0.233890 | 0.121127 | -1.241660 | -0.370602 | -0.163484 | 0.270160 | 0.544093 | 0.178664 | -0.042546 | -0.093032 | 0.380952 | -0.688917 | 0.763480 | -1.208861 | -0.896460 | 0.158241 | 0.066066 | -0.352262 | -0.574730 | ... | 0.067837 | 0.683622 | -0.460075 | -0.038585 | 0.229846 | 0.873871 | 0.432226 | 0.121603 | -0.078207 | 0.291804 | -0.411891 | 0.163638 | -0.173809 | -0.429742 | 0.511413 | -0.495719 | 0.279363 | 0.205129 | -0.706432 | -0.807181 | 0.232116 | 0.224298 | 0.052357 | -0.303295 | 0.352407 | -0.322141 | 0.276664 | -0.241604 | 0.065989 | -0.305942 | -0.000573 | -0.019850 | -0.126656 | 0.264090 | 0.128549 | 0.446585 | -0.629358 | 0.073010 | 0.293735 | 0.308015 |
| 11365 | fb7e90f1-9b50-4758-95db-1e65f2bd5b10 | Pathology Report for Gayla Hamilton\nPatient Information:\n- Name: Gayla Hamilton\n- Date of Birth: July 26, 1984\n- Smoking Status: Non-smoker (0 cigarettes per day)\nLaboratory Results:\n1. Albumin / Globulin Ratio: 1.5 (Reference Range: 6.0-8.5)\n - Abnormal: Yes, below normal\n - Implications: Low albumin/globulin ratio may indicate liver disease or malnutrition.\n2. Bilirubin: 0.4 mg/dL (Reference Range: 0.3-1.9)\n - Abnormal: No\n - Implications: Normal bilirubin levels.\n3. Globulin: 2.9 g/dL (Reference Range: 1.1-2.5)\n - Abnormal: Yes, above normal\n - Implications: Elevated globulin levels may indicate inflammation or a possible infection.\n4. Hemoglobin: 14.4 g/dL (Reference Range: 12.5-17.0)\n - Abnormal: No\n - Implications: Normal hemoglobin levels.\n5. Protein Total: 6.9 g/dL (Reference Range: 6.0-8.3)\n - Abnormal: No\n - Implications: Normal total protein levels.\n6. Triglyceride: 184.0 mg/dL (Reference Range: <150)\n - Abnormal: Yes, above normal\n - Implications: Elevated triglyceride levels may increase the risk of heart disease.\nComments:\nThe abnormal results in albumin/globulin ratio, globulin, and triglycerides may warrant further evaluation and monitoring. Given that the patient is a non-smoker, these abnormalities are unlikely to be directly related to smoking. 
I recommend follow-up tests and consultation with a healthcare provider to address these findings and determine the appropriate management plan.\nSigned,\n[Your Name]\nClinical Pathologist | abnormal albumin globulin globulin triglyceride non smoker abnormality unlikely smoking address appropriate plan signed | 0.042360 | 0.306510 | -0.295259 | -0.268618 | -0.025003 | -0.092417 | -0.18920 | -0.394539 | 0.190384 | -0.232716 | 0.412178 | 0.236966 | -0.118906 | 0.431177 | -0.882996 | -0.651944 | 0.117811 | 0.597496 | 0.244894 | 0.132634 | -0.060487 | -0.017128 | -0.516223 | 0.231145 | 0.269696 | 0.351858 | 0.210463 | 0.487795 | 0.287985 | -0.915484 | -0.119279 | -0.327392 | -0.170393 | 0.584479 | 0.031557 | -0.029035 | ... | 0.715343 | 0.022848 | -0.150271 | -0.345719 | -1.084149 | -0.414737 | -0.306212 | 0.609411 | 0.457310 | 0.407171 | -0.239654 | 0.415814 | 0.249222 | -0.372817 | -0.342280 | 0.028020 | -1.087441 | -0.425647 | 0.342700 | -0.064307 | 0.614285 | 0.617858 | 0.347413 | 0.692275 | -0.201418 | -0.941788 | 0.123500 | -0.559311 | -0.320557 | 0.147939 | -0.002603 | -0.455726 | -0.506842 | 0.827802 | -0.366837 | -0.402616 | 0.165527 | -0.419158 | -0.173898 | 0.157974 |
2 rows × 772 columns
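One way to find the most similar and most distant pairs without overwriting the diagonal of the similarity matrix is to look only at its upper triangle, which holds each unordered pair exactly once and leaves out self-similarity. The following is a minimal sketch using NumPy alone; the function name `extreme_pairs` and the toy vectors are illustrative assumptions rather than part of the notebook, and the normalisation step assumes no row is all zeros.

```python
import numpy as np

def extreme_pairs(vectors):
    """Return the most similar pair, the most distant pair, and all pairwise cosine similarities."""
    X = np.asarray(vectors, dtype=float)
    # Normalise each row to unit length (assumes no all-zero rows)
    unit = X / np.linalg.norm(X, axis=1, keepdims=True)
    sim = unit @ unit.T
    # Upper triangle with k=1: every unordered pair once, diagonal excluded
    rows, cols = np.triu_indices_from(sim, k=1)
    pair_sims = sim[rows, cols]
    best = int(np.argmax(pair_sims))
    worst = int(np.argmin(pair_sims))
    most_similar = (int(rows[best]), int(cols[best]))
    most_distant = (int(rows[worst]), int(cols[worst]))
    return most_similar, most_distant, pair_sims

# Toy 2-D vectors, for illustration only
toy = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
most_similar, most_distant, pair_sims = extreme_pairs(toy)
print(most_similar, most_distant)
```

The returned `pair_sims` array can be passed straight to a histogram, since it already excludes self-similarity. For a matrix with many thousands of rows, the index arrays built by `np.triu_indices_from` can be large, so a boolean mask may be lighter on memory.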
TF-IDF
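As background for the cell that follows, the weighting applied by scikit-learn's `TfidfVectorizer` under its default settings (raw term counts, smoothed IDF, L2 row normalisation) can be sketched in plain Python. This is a simplified illustration: it splits on whitespace rather than using the vectoriser's own token pattern, and the function name `tfidf` and the sample documents are hypothetical.

```python
import math
from collections import Counter

def tfidf(documents):
    """TF-IDF with smoothed IDF, ln((1 + n) / (1 + df)) + 1, and L2-normalised rows."""
    tokenised = [doc.split() for doc in documents]
    vocab = sorted({tok for toks in tokenised for tok in toks})
    n_docs = len(tokenised)
    # Document frequency: number of documents containing each term
    doc_freq = Counter(tok for toks in tokenised for tok in set(toks))
    idf = {tok: math.log((1 + n_docs) / (1 + doc_freq[tok])) + 1 for tok in vocab}
    rows = []
    for toks in tokenised:
        counts = Counter(toks)
        raw = {tok: counts[tok] * idf[tok] for tok in vocab}
        norm = math.sqrt(sum(w * w for w in raw.values()))
        rows.append({tok: (w / norm if norm else 0.0) for tok, w in raw.items()})
    return rows

# Hypothetical cleaned commentaries, for illustration only
docs = ["smoking smoking risk", "non smoker normal", "smoking liver"]
weights = tfidf(docs)
print(weights[1]["smoking"])
```

A document whose cleaned text contains "non smoker" but never the token "smoking" receives a weight of zero for "smoking". This is a plausible reason a report can mention smoking in its full text yet show a "smoking" TF-IDF value of 0.0 when the vectoriser is fitted on the cleaned commentary only.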
# Q1a Step 2
# Perform TF-IDF vectorisation
# Check for NaN values in the cleaned commentary before vectorising
pathology_df['tfidf_commentary'].isna().sum()
# Initialise the TF-IDF vectoriser
# Set max features to 1000 to balance capturing enough significant words
# against letting the dimensionality get out of hand. This also helps to
# reduce the risk of overfitting in later models.
tfidf_vectorizer = TfidfVectorizer(max_features = 1000)
# Fit and transform the "tfidf_commentary" column
# Contain the TF-IDF results in a DataFrame
tfidf_df = pd.DataFrame(tfidf_vectorizer.fit_transform(pathology_df['tfidf_commentary']).toarray(), columns = tfidf_vectorizer.get_feature_names_out())
# Perform PCA to gauge how far the dimensionality of the TF-IDF vectors can be
# reduced; a lower dimensionality should also make the later clustering more effective
tfidf_df_scaled = StandardScaler().fit_transform(tfidf_df)
pca = PCA()
pca.fit(tfidf_df_scaled)
# Visualise the cumulative explained variance ratio
# 1. The plot shows that roughly 500 principal components already explain about
# 70% of the total variance.
# 2. Re-do the TF-IDF vectorisation with max features set to 500. Principal
# components are not the same as vocabulary terms, so this is only a heuristic,
# but halving the max features for a moderate sacrifice in total variance
# explained is a reasonable trade-off, especially because the clustering done
# later on should be more effective and meaningful on a lower-dimensional dataset.
plt.figure(figsize = (8, 5))
plt.step(range(1, len(pca.explained_variance_ratio_) + 1), np.cumsum(pca.explained_variance_ratio_))
plt.xlabel('Number of Principal Components')
plt.ylabel('Cumulative Explained Variance Ratio')
plt.title('Plot of Cumulative Explained Variance Ratio')
plt.show()
# Initialise the TF-IDF vectoriser
# Set max features to 500
tfidf_revised_vectorizer = TfidfVectorizer(max_features = 500)
# Fit and transform the "tfidf_commentary" column
tfidf_revised_df = pd.DataFrame(tfidf_revised_vectorizer.fit_transform(pathology_df['tfidf_commentary']).toarray(), columns = tfidf_revised_vectorizer.get_feature_names_out())
# Check that the number of rows of the TF-IDF vectors is the same as the number
# of rows in pathology_df
if len(tfidf_revised_df) == len(pathology_df):
    print("Number of rows of TF-IDF vectors is the same as in pathology_df.")
else:
    print("Number of rows of TF-IDF vectors is not the same as in pathology_df.")
# Check for NaN values in TF-IDF vectors
print("Total NaN values in TF-IDF vectors:", tfidf_revised_df.isna().sum().sum())
# Check for 0 values in TF-IDF vectors
print("Total zero values in TF-IDF vectors:", (tfidf_revised_df == 0).sum().sum())
# Concatenate the TF-IDF DataFrame with the original data
pathology_df = pd.concat([pathology_df, tfidf_revised_df], axis = 1)
# The most frequent word in the 'tfidf_commentary' column seen earlier in
# Q1a was "smoking".
# Randomly sample three written reports that contain the word "smoking" and
# retrieve their corresponding "smoking" TF-IDF values.
rows = pathology_df[pathology_df['written_report'].str.contains("smoking", case = False, na = False)].sample(n = 3, random_state = random_seed)
for index, row in rows.iterrows():
    written_report = row['written_report']
    smoking_value = row['smoking']
    print("Transcript:", written_report)
    print("\nSmoking value:", smoking_value)
    print("-" * 450)
# Set column width to max
pd.set_option('display.max_colwidth', None)
# Calculate cosine similarity between all TF-IDF vectors
tfidf_revised_similarity_matrix = cosine_similarity(tfidf_revised_df)
# Find the two rows with the highest similarity excluding self-similarity
np.fill_diagonal(tfidf_revised_similarity_matrix, -1)
most_similar_indices = np.unravel_index(np.argmax(tfidf_revised_similarity_matrix), tfidf_revised_similarity_matrix.shape)
# Display the two most similar rows
print("Two most similar rows based on TF-IDF")
display(pathology_df.iloc[[most_similar_indices[0], most_similar_indices[1]]])
# Find the two rows with the lowest similarity excluding self-similarity
np.fill_diagonal(tfidf_revised_similarity_matrix, 1)
most_distant_indices = np.unravel_index(np.argmin(tfidf_revised_similarity_matrix), tfidf_revised_similarity_matrix.shape)
# Display the two most distant rows
print("Two most distant rows based on TF-IDF:")
display(pathology_df.iloc[[most_distant_indices[0], most_distant_indices[1]]])
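# Histogram of cosine similarity between all distinct pairs of TF-IDF vectors
# Use only the upper triangle (k = 1) so that the diagonal (self-similarity) and
# the mirrored duplicate of each pair are excluded
plt.figure(figsize = (8, 5))
plt.hist(tfidf_revised_similarity_matrix[np.triu(np.ones(tfidf_revised_similarity_matrix.shape, dtype = bool), k = 1)], bins = 20, edgecolor = 'black')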
plt.xlabel("Cosine Similarity")
plt.ylabel("Frequency")
plt.title("Distribution of Cosine Similarity based on TF-IDF (excluding self-similarity)")
plt.grid()
plt.show()
Number of rows of TF-IDF vectors is the same as in pathology_df.
Total NaN values in TF-IDF vectors: 0
Total zero values in TF-IDF vectors: 6374256
Transcript: **Pathology Report:**
**Patient Information:**
- Name: Mary Johnson
- Date of Birth: December 27, 1973
- Smoking Status: 1-2 packs per day (current smoker)
**Abnormal Results:**
1. **Total Protein:** 6.9 g/dL (Reference Range: 6.0-8.3 g/dL, Alert Low)
2. **Triglycerides:** 261.0 mg/dL (Reference Range: 150-200 mg/dL, Panic Low)
**Comments:**
The blood test results for Mary Johnson show abnormalities in total protein and triglyceride levels. The total protein level is low, indicating potential malnutrition or liver disease. The triglyceride level is significantly elevated, which can increase the risk of cardiovascular disease. Given Mary's smoking status as a current smoker, this could further exacerbate the risk of cardiovascular disease. Smoking is a major risk factor for heart disease, and the combination of elevated triglycerides and smoking puts Mary at a higher risk.
**Recommendations:**
1. Further evaluation of liver function and nutritional status may be warranted to determine the cause of the low total protein levels.
2. Lifestyle changes, including smoking cessation and dietary modifications, should be considered to reduce the risk of cardiovascular disease.
Additional follow-up with a healthcare provider is recommended to address the abnormal results and develop a personalized plan for monitoring and managing Mary's overall health.
Smoking value: 0.16290185938703658 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Transcript: **Pathology Report:** **Patient Information:** - Name: Steven Wood - Date of Birth: July 27, 1965 - Smoking Status: Non-smoker (0 cigarettes per day) **Abnormal Results:** 1. **Bilirubin:** - Result: 0.6 mg/dL (Reference Range: 0.2-1.2 mg/dL) - Implication: The bilirubin level is within the normal range, indicating normal liver function. **Comments:** Based on the blood test results, Steven Wood's bilirubin level is within the normal range, suggesting normal liver function. Given that Steven is a non-smoker, there are no significant abnormalities to report. No further action or treatment is necessary at this time. --- Dr. [Your Name] Clinical Pathologist Smoking value: 0.0 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Transcript: Pathology Report: Patient: Andrea Elkins Date of Birth: June 28, 1939 Smoking Status: 0 cigarettes per day (previous smoker) Abnormal Results: 1. Hematocrit: 40.0% (Above Normal High) 2. 
Hemoglobin: 13.3 g/dL (Above Normal High) Comments: The blood test results for Andrea Elkins show elevated levels of hematocrit and hemoglobin, which are both above the normal range. These findings may indicate a condition such as dehydration, polycythemia, or underlying heart or lung issues. Considering that Andrea is a previous smoker, these abnormal results could potentially be related to her smoking history. Recommendations: 1. Further evaluation is recommended to determine the underlying cause of the elevated hematocrit and hemoglobin levels. This may involve additional diagnostic tests and consultation with a healthcare provider. 2. Since Andrea is a previous smoker, it is important to discuss any potential health risks associated with her smoking history and make appropriate lifestyle changes to improve her overall health. Smoking value: 0.11984624263435369 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Two most similar rows based on TF-IDF
| LabResultGuid | written_report | redacted_commentary | tfidf_commentary | embedding_0 | embedding_1 | embedding_2 | embedding_3 | embedding_4 | embedding_5 | embedding_6 | embedding_7 | embedding_8 | embedding_9 | embedding_10 | embedding_11 | embedding_12 | embedding_13 | embedding_14 | embedding_15 | embedding_16 | embedding_17 | embedding_18 | embedding_19 | embedding_20 | embedding_21 | embedding_22 | embedding_23 | embedding_24 | embedding_25 | embedding_26 | embedding_27 | embedding_28 | embedding_29 | embedding_30 | embedding_31 | embedding_32 | embedding_33 | embedding_34 | embedding_35 | ... | tailored | term | thank | thorough | thrombocytopenia | thyroid | thyroxine | time | tobacco | track | tract | treatment | triglyceride | troponin | tsh | type | ul | undergo | underlying | understand | unlikely | upper | ups | uric | urinalysis | urinary | urine | urobilinogen | use | various | vera | vitamin | vldl | warranted | weakness | weight | white | work | worth | young | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3385 | fa2f17a0-38c6-422f-9107-0d65113af7bd | ### Pathology Report:\n**Patient Information:**\n- Name: Juanita Garcia\n- Date of Birth: February 17, 1950\n- Smoking Status: Up to 1 pack per day\n---\n#### Abnormal Results:\n1. **Bilirubin:**\n - Result: 0.4 mg/dL\n - Reference Range: 0.2-1 mg/dL\n - Abnormal: No\n2. **Uric Acid, Serum:**\n - Result: 9.8 ml/min/1.73m2\n - Reference Range: 2.5-7.7 ml/min/1.73m2\n - Abnormal: Yes\n---\n### Comments:\nThe blood test results for Juanita Garcia show an elevated level of Uric Acid in the serum, which is above the normal reference range. This abnormal result can indicate underlying conditions such as gout or kidney disease. Given the patient's smoking status, it is important to note that smoking can also contribute to elevated uric acid levels, potentially exacerbating the risk of gout or kidney dysfunction.\n### Recommendations:\n1. Further evaluation and monitoring of kidney function are warranted due to the elevated uric acid levels.\n2. Encourage Juanita to consider smoking cessation to help manage her uric acid levels and overall health. | \nThe blood test results for {PATIENT} show an elevated level of Uric Acid in the serum, which is above the normal reference range. This abnormal result can indicate underlying conditions such as gout or kidney disease. Given the patient's smoking status, it is important to note that smoking can also contribute to elevated uric acid levels, potentially exacerbating the risk of gout or kidney dysfunction.\n### Recommendations:\n1. Further evaluation and monitoring of kidney function are warranted due to the elevated uric acid levels.\n2. {PATIENT} to consider smoking cessation to help manage her uric acid levels and overall health. 
| elevated uric acid serum above normal abnormal underlying gout kidney disease smoking note smoking contribute elevated uric acid potentially exacerbating risk gout kidney dysfunction kidney warranted elevated uric acid smoking cessation help manage uric acid | -0.076355 | 0.350946 | 0.353220 | -0.352238 | 0.402397 | -0.679165 | 0.957638 | -0.711091 | -0.899945 | 0.121093 | 0.132902 | 0.353399 | 0.348123 | -0.539971 | -0.536249 | -0.096731 | -0.095326 | 0.450154 | 0.088411 | 0.616706 | -0.448070 | 0.121512 | -0.641079 | 0.278021 | -0.359558 | 0.211616 | 0.273447 | 0.000121 | -0.232263 | 0.124817 | 0.313549 | -0.823237 | -0.295116 | 0.341600 | -0.235911 | -0.182944 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.041464 | 0.0 | 0.0 | 0.0 | 0.0 | 0.717799 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.085401 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 7155 | 20eecf0f-2b80-404a-9a34-6b9ac1b53bf8 | **Pathology Report**\n**Patient Information:**\n- Name: Janet Brown\n- Date of Birth: September 3, 1972\n- Smoking Status: 0 cigarettes per day (previous smoker)\n---\n**Abnormal Results:**\n1. **Uric Acid, Serum:** \n - Observation Value: 268.0 mg/dL\n - Reference Range: Cutoff=25\n - Abnormal: Yes\n---\n**Comments:**\nThe uric acid level in Janet Brown's serum is significantly elevated at 268.0 mg/dL, which is above the normal reference range of Cutoff=25 mg/dL. Elevated uric acid levels can be indicative of conditions such as gout, kidney disease, and metabolic disorders.\nGiven that Janet Brown is a previous smoker, it is essential to consider lifestyle factors that may contribute to elevated uric acid levels. It is recommended that further evaluation be conducted to determine the underlying cause of the high uric acid levels and appropriate management strategies be implemented.\n--- | **\nThe uric acid level in {PATIENT} serum is significantly elevated at 268.0 mg/dL, which is above the normal reference range of Cutoff=25 mg/dL. Elevated uric acid levels can be indicative of conditions such as gout, kidney disease, and metabolic disorders.\nGiven that {PATIENT} is a previous smoker, it is essential to consider lifestyle factors that may contribute to elevated uric acid levels. 
It is recommended that further evaluation be conducted to determine the underlying cause of the high uric acid levels and appropriate management strategies be implemented.\n--- | uric acid serum significantly elevated above normal cutoff elevated uric acid indicative gout kidney disease metabolic disorder smoker essential contribute elevated uric acid conducted underlying high uric acid appropriate strategy implemented | -0.074648 | 0.011957 | 0.274452 | -0.552526 | 0.468873 | -0.705488 | 1.035491 | -0.430184 | -0.714982 | 0.738251 | 0.132844 | 0.411108 | 0.231654 | -0.688081 | -0.421136 | -0.116273 | 0.061279 | 0.138338 | 0.049577 | 0.250031 | -0.261141 | 0.096047 | -0.286078 | -0.060955 | -0.595730 | 0.249801 | 0.568583 | -0.208826 | -0.348316 | -0.095841 | 0.256388 | -0.832045 | -0.427881 | 0.100973 | -0.169239 | -0.143823 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.042088 | 0.0 | 0.0 | 0.0 | 0.0 | 0.728597 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2 rows × 1272 columns
Two most distant rows based on TF-IDF:
| LabResultGuid | written_report | redacted_commentary | tfidf_commentary | embedding_0 | embedding_1 | embedding_2 | embedding_3 | embedding_4 | embedding_5 | embedding_6 | embedding_7 | embedding_8 | embedding_9 | embedding_10 | embedding_11 | embedding_12 | embedding_13 | embedding_14 | embedding_15 | embedding_16 | embedding_17 | embedding_18 | embedding_19 | embedding_20 | embedding_21 | embedding_22 | embedding_23 | embedding_24 | embedding_25 | embedding_26 | embedding_27 | embedding_28 | embedding_29 | embedding_30 | embedding_31 | embedding_32 | embedding_33 | embedding_34 | embedding_35 | ... | tailored | term | thank | thorough | thrombocytopenia | thyroid | thyroxine | time | tobacco | track | tract | treatment | triglyceride | troponin | tsh | type | ul | undergo | underlying | understand | unlikely | upper | ups | uric | urinalysis | urinary | urine | urobilinogen | use | various | vera | vitamin | vldl | warranted | weakness | weight | white | work | worth | young | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7ee243d0-a158-46fa-8d0f-1ca401de2b77 | **Pathology Report**\n**Patient Information:**\n- Name: Gabriel Warwick\n- Date of Birth: March 27, 1994\n- Smoking Status: 0 cigarettes per day (previous smoker)\n---\n**Blood Test Results:**\n1. Hematocrit: 38.3% (Reference Range: 36.0-50.0)\n - Result: Within normal range\n2. Immature Granulocytes: 74.2 g/dL\n - Result: Abnormal, further investigation may be required\n3. Ketones: Not Available\n - Result: Not available for analysis\n4. Total Protein: 1.8 g/dL (Reference Range: 0.0-1.2)\n - Result: Abnormal, higher than the reference range\n5. Urobilinogen: 9.2 mg/dL (Reference Range: Yellow)\n - Result: Within normal range\n---\n**Comments:**\n- The Immature Granulocytes level of 74.2 g/dL is abnormal and may indicate an ongoing infection or inflammation. Further evaluation and monitoring may be needed.\n- The Total Protein level of 1.8 g/dL is higher than the reference range. This could be due to various factors such as dehydration, liver disease, or inflammation. Additional testing and clinical correlation are recommended.\n- The patient's previous smoking status may have contributed to the abnormal blood test results. It is important to address smoking cessation and its potential impact on overall health.\n---\n**Recommendations:**\n1. Follow-up testing to monitor the Immature Granulocytes levels and investigate possible underlying conditions.\n2. Further evaluation of Total Protein levels to determine the cause of the abnormal result.\n3. Encourage lifestyle modifications, including smoking cessation, to improve overall health and potentially normalize future blood test results. \nPlease consult with the patient and consider a referral to a healthcare provider for additional evaluation and management if required. | **\n- The Immature Granulocytes level of 74.2 g/dL is abnormal and may indicate an ongoing infection or inflammation. 
Further evaluation and monitoring may be needed.\n- The Total Protein level of 1.8 g/dL is higher than the reference range. This could be due to various factors such as dehydration, liver disease, or inflammation. Additional testing and clinical correlation are recommended.\n- The patient's previous smoking status may have contributed to the abnormal blood test results. It is important to address smoking cessation and its potential impact on overall health.\n---\n**Recommendations:**\n1. Follow-up testing to monitor the Immature Granulocytes levels and investigate possible underlying conditions.\n2. Further evaluation of Total Protein levels to determine the cause of the abnormal result.\n3. Encourage lifestyle modifications, including smoking cessation, to improve overall health and potentially normalize future blood test results. \nPlease consult with the patient and consider a referral to a healthcare provider for additional evaluation and management if required. | immature granulocyte g abnormal ongoing infection inflammation protein g higher various dehydration liver disease inflammation correlation smoking contributed abnormal address smoking cessation potential immature granulocyte investigate possible underlying protein abnormal encourage smoking cessation improve potentially normalize future referral | -0.407162 | -0.337454 | 0.517488 | -0.540889 | -0.139277 | 0.299269 | -0.110034 | -0.173800 | -0.068978 | 0.446372 | 0.156135 | 0.630184 | 0.151318 | -0.0807 | -0.547392 | -0.241517 | 0.752692 | 0.218819 | 0.160436 | -0.028408 | 0.192546 | -0.165921 | -0.900564 | 0.058228 | -0.513585 | -0.340527 | -0.011229 | -0.378836 | 0.005880 | -0.159155 | 0.148313 | -0.780023 | -0.307246 | 0.209648 | 0.172657 | -0.533083 | ... 
| 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.08318 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.148389 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 53 | 138702ef-7a31-4b56-80d7-4de60732b95a | **Pathology Report:**\n**Patient Information:**\n- Name: Florine Parra\n- Date of Birth: August 20, 1973\n- Smoking Status: Non-smoker or less than 100 cigarettes in a lifetime\n**Laboratory Test Results:**\n1. Albumin / Globulin Ratio: 1.6 K/uL (Normal)\n2. Cholesterol / HDL Ratio: 0.5% (Normal)\n3. Chloride, Serum: 5.4 x10E3/uL (Normal)\n4. Nitrite: Not Available\n5. pH: 7.1 mmol/L (Normal)\n**Comments:**\nAll the reported laboratory test results are within normal ranges for the patient Florine Parra. It is important to note that the Nitrite result is not available, so further investigation may be required to obtain this information.\n**Recommendation:**\nGiven that all the results are normal, no specific medical interventions are required at this time. It is advisable for the patient to continue with regular health check-ups and maintain a healthy lifestyle, including a balanced diet and regular exercise routine. If there are any concerns or symptoms, the patient should consult with their healthcare provider for further evaluation. | **\nAll the reported laboratory test results are within normal ranges for the patient {PATIENT}. It is important to note that the Nitrite result is not available, so further investigation may be required to obtain this information.\n**Recommendation:**\nGiven that all the results are normal, no specific medical interventions are required at this time. It is advisable for the patient to continue with regular health check-ups and maintain a healthy lifestyle, including a balanced diet and regular exercise routine. If there are any concerns or symptoms, the patient should consult with their healthcare provider for further evaluation. 
| reported laboratory normal note nitrite obtain information normal no specific medical intervention time advisable continue regular check ups maintain healthy balanced diet regular exercise routine concern symptom | 0.306081 | -0.189231 | 0.752114 | -0.179573 | 0.326037 | -0.212856 | 0.461022 | -0.304251 | -0.622976 | 0.813993 | 0.366816 | 0.183550 | 0.350794 | -0.2646 | -0.128494 | 0.923634 | 0.300509 | -0.435127 | 0.088184 | -0.243197 | -0.149535 | 0.045458 | -0.241739 | -0.129985 | -0.292143 | 0.091898 | 0.336769 | 0.540997 | -0.039589 | 0.097746 | -0.240209 | -0.292988 | -0.688818 | -0.398333 | -0.286231 | 0.362058 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.187019 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.185107 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2 rows × 1272 columns
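The "most distant rows" comparison above can be illustrated on a toy corpus. This is a minimal sketch only: the three documents are hypothetical stand-ins for the cleaned commentaries, and cosine distance is an assumption, since this section does not show which distance the notebook used.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import pairwise_distances

# Hypothetical cleaned commentaries (stand-ins for the tfidf_commentary column)
docs = [
    "elevated uric acid gout kidney",
    "uric acid elevated kidney disease",
    "normal results regular exercise diet",
]

tfidf_matrix = TfidfVectorizer().fit_transform(docs)

# Full pairwise cosine distance matrix (assumed metric)
dist = pairwise_distances(tfidf_matrix, metric="cosine")

# Most distant pair: largest entry of the matrix
far_i, far_j = np.unravel_index(np.argmax(dist), dist.shape)

# Closest pair: mask the zero diagonal before taking the minimum
masked = dist.copy()
np.fill_diagonal(masked, np.inf)
near_i, near_j = np.unravel_index(np.argmin(masked), masked.shape)

print("Most distant pair:", (far_i, far_j))
print("Closest pair:", (near_i, near_j))
```

A full pairwise matrix grows with the square of the row count, so for many thousands of reports it may be worth computing distances in chunks or on a sample.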
Q1b - Clustering¶
Q1b Step 1 - PCA on Embedding Features¶
# Q1b Step 1
# Place embeddings in separate DataFrame
embeddings_df = pathology_df[[col for col in pathology_df.columns if col.startswith("embedding_")]]
# Check how many embeddings
print(embeddings_df.shape)
# Standardise the data
embeddings_df_scaled = StandardScaler().fit_transform(embeddings_df)
# Perform PCA
embeddings_pca = PCA(n_components = 50)
embeddings_pca_result = embeddings_pca.fit_transform(embeddings_df_scaled)
# Store PCA results in DataFrame
embeddings_pca_df = pd.DataFrame(data = embeddings_pca_result, columns = [f"embeddings_pca_{i}" for i in range(embeddings_pca.n_components_)])
# Checks for missing or infinite values and confirm all columns are numeric
embeddings_pca_errors = {
    "has_NaN": embeddings_pca_df.isnull().values.any(),
    "has_inf": np.isinf(embeddings_pca_df.values).any(),
    "all_numeric": np.all([np.issubdtype(dtype, np.number) for dtype in embeddings_pca_df.dtypes])
}
print("Embeddings PCA Checks:")
for key, value in embeddings_pca_errors.items():
    print(f"{key}: {value}")
# Visualise explained variance ratios
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(embeddings_pca.explained_variance_ratio_) + 1), np.cumsum(embeddings_pca.explained_variance_ratio_), marker = 'o')
plt.xlabel("Number of Principal Components")
plt.ylabel("Cumulative Explained Variance Ratio")
plt.title("Plot of Cumulative Explained Variance Ratio (Embeddings)")
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(range(1, len(embeddings_pca.explained_variance_ratio_) + 1), embeddings_pca.explained_variance_ratio_, marker = 'o')
plt.xlabel("Principal Component")
plt.ylabel("Explained Variance Ratio")
plt.title("Scree Plot (i.e. Explained Variance Ratio per PC) (Embeddings)")
plt.grid()
plt.tight_layout()
plt.show()
# Print explained variance ratios
print(f"Cumulative explained variance ratio from first 15 PCs (embeddings): {np.cumsum(embeddings_pca.explained_variance_ratio_)[:15][-1]:.4f}")
# Choose first 15 PCs from embeddings PCA as they explain roughly 60% of total
# variance collectively. Generally, we would prefer to see around 70% of total
# variance explained, but in this context, we will still be getting some
# explanation of the total variance from the PCs of the TF-IDF PCA, so there is
# sufficient reason to adjust the benchmark down slightly to 60%.
embeddings_pca_revised_df = embeddings_pca_df.iloc[:, :15]
(13489, 768)
Embeddings PCA Checks:
has_NaN: False
has_inf: False
all_numeric: True
Cumulative explained variance ratio from first 15 PCs (embeddings): 0.5945
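The choice of 15 PCs above was read off the cumulative explained variance curve. As a hedged sketch of the same idea, the smallest number of PCs that reaches a given threshold can be computed directly. The data below is synthetic, and `components_for_threshold` is a hypothetical helper, not something defined in the notebook.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Synthetic stand-in for the embedding matrix: 30 features driven by 5 latent factors
latent = rng.normal(size=(200, 5))
X = latent @ rng.normal(size=(5, 30)) + 0.5 * rng.normal(size=(200, 30))

X_scaled = StandardScaler().fit_transform(X)
pca = PCA(n_components=20).fit(X_scaled)
cumulative = np.cumsum(pca.explained_variance_ratio_)


def components_for_threshold(cumulative_ratio, threshold):
    """Return the smallest number of leading PCs whose cumulative ratio
    reaches the threshold, or None if the fitted PCs never reach it."""
    reached = np.flatnonzero(cumulative_ratio >= threshold)
    return int(reached[0]) + 1 if reached.size else None


k60 = components_for_threshold(cumulative, 0.60)
print(f"PCs needed for 60% of variance: {k60}")
```

Returning `None` when the threshold is never reached makes it obvious when `n_components` was set too low to answer the question.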
Q1b Step 2 - PCA on TF-IDF Features¶
# Q1b Step 2
# Standardise the data
tfidf_revised_df_scaled = StandardScaler().fit_transform(tfidf_revised_df)
# Perform PCA
tfidf_pca = PCA(n_components = 50)
tfidf_pca_result = tfidf_pca.fit_transform(tfidf_revised_df_scaled)
# Store PCA results in DataFrame
tfidf_pca_df = pd.DataFrame(data = tfidf_pca_result, columns = [f"tfidf_pca_{i}" for i in range(tfidf_pca.n_components_)])
# Checks for missing or infinite values and confirm all columns are numeric
tfidf_pca_errors = {
    "has_NaN": tfidf_pca_df.isnull().values.any(),
    "has_inf": np.isinf(tfidf_pca_df.values).any(),
    "all_numeric": np.all([np.issubdtype(dtype, np.number) for dtype in tfidf_pca_df.dtypes])
}
print("TF-IDF PCA Checks:")
for key, value in tfidf_pca_errors.items():
    print(f"{key}: {value}")
# Visualise explained variance ratios
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(tfidf_pca.explained_variance_ratio_) + 1), np.cumsum(tfidf_pca.explained_variance_ratio_), marker = 'o')
plt.xlabel("Number of Principal Components")
plt.ylabel("Cumulative Explained Variance Ratio")
plt.title("Plot of Cumulative Explained Variance Ratio (TF-IDF)")
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(range(1, len(tfidf_pca.explained_variance_ratio_) + 1), tfidf_pca.explained_variance_ratio_, marker = 'o')
plt.xlabel("Principal Component")
plt.ylabel("Explained Variance Ratio")
plt.title("Scree Plot (i.e. Explained Variance Ratio per PC) (TF-IDF)")
plt.grid()
plt.tight_layout()
plt.show()
# Print explained variance ratios
print(f"Cumulative explained variance ratio from first 10 PCs (TF-IDF): {np.cumsum(tfidf_pca.explained_variance_ratio_)[:10][-1]:.4f}")
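# Choose first 10 PCs from TF-IDF PCA; they explain roughly 9% of the variance
# in the standardised TF-IDF features. This ratio is relative to the TF-IDF
# feature set only, so it is not additive with the roughly 60% from the
# embeddings PCA: the two percentages have different denominators. The 10
# TF-IDF PCs are retained alongside the 15 embedding PCs.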
tfidf_pca_revised_df = tfidf_pca_df.iloc[:, :10]
TF-IDF PCA Checks:
has_NaN: False
has_inf: False
all_numeric: True
Cumulative explained variance ratio from first 10 PCs (TF-IDF): 0.0900
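The explained variance ratios reported for the embeddings PCA and the TF-IDF PCA come from two separate fits, so each is a share of its own feature block. One hedged way to obtain a single figure is to pool the retained variance and the total variance across both blocks. The sketch below uses synthetic blocks; `retained_variance` is a hypothetical helper and the block sizes are arbitrary.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Two synthetic feature blocks (stand-ins for the embedding and TF-IDF features)
block_a = rng.normal(size=(300, 40)) @ rng.normal(size=(40, 40))  # correlated features
block_b = rng.normal(size=(300, 60))                              # nearly uncorrelated features


def retained_variance(block, n_keep):
    """Fit a full PCA on the standardised block and return
    (variance kept by the first n_keep PCs, total variance of the block)."""
    scaled = StandardScaler().fit_transform(block)
    pca = PCA().fit(scaled)
    return pca.explained_variance_[:n_keep].sum(), pca.explained_variance_.sum()


ret_a, tot_a = retained_variance(block_a, 10)
ret_b, tot_b = retained_variance(block_b, 10)

share_a = ret_a / tot_a                               # share within block A only
share_b = ret_b / tot_b                               # share within block B only
pooled_share = (ret_a + ret_b) / (tot_a + tot_b)      # share of all variance in both blocks

print(f"Block A: {share_a:.3f}, Block B: {share_b:.3f}, pooled: {pooled_share:.3f}")
```

The pooled share is a weighted average of the two per-block shares, so it always lies between them and is never their sum.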
Q1b Step 3 - Agglomerative Clustering¶
# Q1b Step 3
# Combine selected PCs for clustering
combined_pca_df = pd.concat([embeddings_pca_revised_df, tfidf_pca_revised_df], axis = 1)
# Function to calculate Dunn Index
def dunn_index(X, labels):
    n_clusters = len(np.unique(labels))
    # Avoid division by zero for a single cluster
    if n_clusters == 1:
        return 0
    distances = pairwise_distances(X)
    cluster_indices = [np.where(labels == i)[0] for i in np.unique(labels)]
    min_inter_cluster_distance = np.inf
    for i in range(n_clusters):
        for j in range(i + 1, n_clusters):
            min_inter_cluster_distance = min(min_inter_cluster_distance, distances[cluster_indices[i], :][:, cluster_indices[j]].min())
    max_intra_cluster_distance = 0
    for i in range(n_clusters):
        max_intra_cluster_distance = max(max_intra_cluster_distance, distances[cluster_indices[i], :][:, cluster_indices[i]].max())
    return min_inter_cluster_distance / max_intra_cluster_distance
# Choosing distance method and linkage method
# Set number of clusters to 5 to test distance method and linkage method
distance_measures = ['euclidean', 'manhattan', 'cosine']
linkage_methods = ['ward', 'complete', 'average', 'single']
distance_linkage_test_results_df = []
for linkage in linkage_methods:
    for distance in distance_measures:
        try:
            # Use distance measure only when linkage method is not ward as model
            # only accepts Euclidean distance measure when linkage method is
            # ward
            if linkage == 'ward' and distance != 'euclidean':
                continue
            if linkage == 'ward':
                agg_clustering = AgglomerativeClustering(n_clusters = 5, linkage = linkage)
            else:
                agg_clustering = AgglomerativeClustering(n_clusters = 5, linkage = linkage, metric = distance)
            labels = agg_clustering.fit_predict(combined_pca_df)
            db_score = davies_bouldin_score(combined_pca_df, labels)
            dunn_score = dunn_index(combined_pca_df, labels)
            silhouette_avg = silhouette_score(combined_pca_df, labels)
            ch_score = calinski_harabasz_score(combined_pca_df, labels)
            distance_linkage_test_results_df.append([distance, linkage, db_score, dunn_score, silhouette_avg, ch_score])
        except ValueError as e:
            print(f"Error with distance = {distance}, linkage = {linkage}: {e}")
            distance_linkage_test_results_df.append([distance, linkage, np.nan, np.nan, np.nan, np.nan])
distance_linkage_test_results_df = pd.DataFrame(distance_linkage_test_results_df, columns = ['Distance', 'Linkage', 'Davies-Bouldin', 'Dunn Index', 'Silhouette Score', 'Calinski-Harabasz'])
display(distance_linkage_test_results_df)
| | Distance | Linkage | Davies-Bouldin | Dunn Index | Silhouette Score | Calinski-Harabasz |
|---|---|---|---|---|---|---|
| 0 | euclidean | ward | 2.514606 | 0.109482 | 0.094221 | 958.597162 |
| 1 | euclidean | complete | 3.401520 | 0.116441 | 0.028450 | 542.314655 |
| 2 | manhattan | complete | 3.348373 | 0.107429 | 0.019623 | 537.881190 |
| 3 | cosine | complete | 3.690378 | 0.100329 | 0.033235 | 655.558891 |
| 4 | euclidean | average | 2.163915 | 0.120086 | 0.086648 | 727.831331 |
| 5 | manhattan | average | 1.550701 | 0.182545 | -0.000670 | 20.512399 |
| 6 | cosine | average | 2.417753 | 0.127366 | 0.100138 | 1017.444612 |
| 7 | euclidean | single | 0.742350 | 0.349567 | 0.031867 | 1.885978 |
| 8 | manhattan | single | 0.816860 | 0.323038 | 0.010591 | 1.603929 |
| 9 | cosine | single | 1.133902 | 0.235153 | -0.182817 | 1.089374 |
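The Dunn index in the table above is the smallest distance between points in different clusters divided by the largest distance between points in the same cluster. A minimal sketch on four hand-placed points makes the ratio easy to check by eye. `dunn_index_sketch` is an illustrative rewrite assuming Euclidean distances, not the notebook's own function.

```python
import numpy as np
from sklearn.metrics import pairwise_distances


def dunn_index_sketch(X, labels):
    """Min between-cluster distance divided by max within-cluster distance."""
    labels = np.asarray(labels)
    unique = np.unique(labels)
    if unique.size < 2:
        return 0.0  # undefined for a single cluster; return 0 by convention here
    d = pairwise_distances(X)  # Euclidean by default
    min_between = np.inf
    max_within = 0.0
    for pos, a in enumerate(unique):
        in_a = labels == a
        max_within = max(max_within, d[np.ix_(in_a, in_a)].max())
        for b in unique[pos + 1:]:
            min_between = min(min_between, d[np.ix_(in_a, labels == b)].min())
    return min_between / max_within


# Four points on a line: two tight pairs separated by a large gap
points = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [11.0, 0.0]])

tight = dunn_index_sketch(points, [0, 0, 1, 1])  # gap of 9, diameter of 1
mixed = dunn_index_sketch(points, [0, 1, 0, 1])  # gap of 1, diameter of 10

print(f"Well-separated labelling: {tight:.2f}")
print(f"Interleaved labelling:    {mixed:.2f}")
```

Because the index depends only on two extreme distances, it can reward a partition made of one large cluster plus a few isolated points. That would be consistent with single linkage scoring highest on the Dunn index but lowest on Calinski-Harabasz in the table, although the cluster sizes are not shown here to confirm it.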
# Q1b Step 3
# Choosing number of clusters
# Test cluster counts from 3 to 10
cluster_count_test_results = []
for n_clusters in range(3, 11):
    agg_clustering = AgglomerativeClustering(n_clusters = n_clusters, metric = 'euclidean', linkage = 'ward')
    agg_clustering.fit(combined_pca_df)
    labels = agg_clustering.labels_
    db_score = davies_bouldin_score(combined_pca_df, labels)
    dunn_score = dunn_index(combined_pca_df, labels)
    silhouette_avg = silhouette_score(combined_pca_df, labels)
    ch_score = calinski_harabasz_score(combined_pca_df, labels)
    cluster_count_test_results.append([n_clusters, db_score, dunn_score, silhouette_avg, ch_score])
# Place results in a DataFrame
cluster_count_test_results_df = pd.DataFrame(cluster_count_test_results, columns = ['Clusters', 'Davies-Bouldin', 'Dunn Index', 'Silhouette Score', 'Calinski-Harabasz'])
# Display plots of the internal validation metrics
plt.figure(figsize = (12, 6))
plt.subplot(2, 2, 1)
plt.plot(cluster_count_test_results_df['Clusters'], cluster_count_test_results_df['Davies-Bouldin'], marker = 'o', color = 'blue')
plt.xlabel('Number of Clusters')
plt.ylabel('Davies-Bouldin Index')
plt.title('Davies-Bouldin Index vs. Number of Clusters')
plt.grid()
plt.subplot(2, 2, 2)
plt.plot(cluster_count_test_results_df['Clusters'], cluster_count_test_results_df['Dunn Index'], marker = 'o', color = 'black')
plt.xlabel('Number of Clusters')
plt.ylabel('Dunn Index')
plt.title('Dunn Index vs. Number of Clusters')
plt.grid()
plt.subplot(2, 2, 3)
plt.plot(cluster_count_test_results_df['Clusters'], cluster_count_test_results_df['Silhouette Score'], marker = 'o', color = 'green')
plt.xlabel('Number of Clusters')
plt.ylabel('Silhouette Score')
plt.title('Silhouette Score vs. Number of Clusters')
plt.grid()
plt.subplot(2, 2, 4)
plt.plot(cluster_count_test_results_df['Clusters'], cluster_count_test_results_df['Calinski-Harabasz'], marker = 'o', color = 'red')
plt.xlabel('Number of Clusters')
plt.ylabel('Calinski-Harabasz Index')
plt.title('Calinski-Harabasz Index vs. Number of Clusters')
plt.grid()
plt.tight_layout()
plt.show()
# Find the recommended number of clusters based on each metric
recommended_clusters = {}
# Davies-Bouldin: Minimise the score
min_db_index = cluster_count_test_results_df['Davies-Bouldin'].min()
recommended_clusters['Davies-Bouldin'] = cluster_count_test_results_df[cluster_count_test_results_df['Davies-Bouldin'] == min_db_index]['Clusters'].iloc[0]
# Dunn Index: Maximise the score
max_dunn_index = cluster_count_test_results_df['Dunn Index'].max()
recommended_clusters['Dunn Index'] = cluster_count_test_results_df[cluster_count_test_results_df['Dunn Index'] == max_dunn_index]['Clusters'].iloc[0]
# Silhouette Score: Maximise the score
max_silhouette = cluster_count_test_results_df['Silhouette Score'].max()
recommended_clusters['Silhouette Score'] = cluster_count_test_results_df[cluster_count_test_results_df['Silhouette Score'] == max_silhouette]['Clusters'].iloc[0]
# Calinski-Harabasz: Maximise the score
max_ch_score = cluster_count_test_results_df['Calinski-Harabasz'].max()
recommended_clusters['Calinski-Harabasz'] = cluster_count_test_results_df[cluster_count_test_results_df['Calinski-Harabasz'] == max_ch_score]['Clusters'].iloc[0]
# Print the recommended number of clusters from each metric
# Dunn Index and CH scores indicate using 3 clusters
# Silhouette score indicates using 6 clusters
# DB score indicates using 8 clusters
for metric, n_clusters in recommended_clusters.items():
    print(f"Recommended number of clusters based on {metric}: {n_clusters}")
# Perform agglomerative clustering on 3 clusters
agg_clustering_three = AgglomerativeClustering(n_clusters = 3, metric = 'euclidean', linkage = 'ward')
agg_clustering_three.fit(combined_pca_df)
cluster_labels_three = agg_clustering_three.labels_
# Perform agglomerative clustering on 6 clusters
agg_clustering_six = AgglomerativeClustering(n_clusters = 6, metric = 'euclidean', linkage = 'ward')
agg_clustering_six.fit(combined_pca_df)
cluster_labels_six = agg_clustering_six.labels_
# Perform agglomerative clustering on 8 clusters
agg_clustering_eight = AgglomerativeClustering(n_clusters = 8, metric = 'euclidean', linkage = 'ward')
agg_clustering_eight.fit(combined_pca_df)
cluster_labels_eight = agg_clustering_eight.labels_
# Apply t-SNE to reduce dimensions to two for easier visualisation
tsne = TSNE(n_components = 2, random_state = random_seed)
tsne_result = tsne.fit_transform(combined_pca_df)
# Compute and print internal validation metrics for 3 agglomerative clusters
db_score_three = davies_bouldin_score(combined_pca_df, cluster_labels_three)
dunn_score_three = dunn_index(combined_pca_df, cluster_labels_three)
silhouette_avg_three = silhouette_score(combined_pca_df, cluster_labels_three)
ch_score_three = calinski_harabasz_score(combined_pca_df, cluster_labels_three)
print(f"Internal validation metrics for three agglomerative clusters:")
print(f"Davies-Bouldin Index : {db_score_three:.4f}")
print(f"Dunn Index : {dunn_score_three:.4f}")
print(f"Silhouette Score : {silhouette_avg_three:.4f}")
print(f"Calinski-Harabasz Score : {ch_score_three:.2f}")
# Create silhouette plot with 3 agglomerative clusters
sample_silhouette_values_three = silhouette_samples(combined_pca_df, cluster_labels_three)
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
y_lower = 10
for i in range(3):
    ith_cluster_silhouette_values = sample_silhouette_values_three[cluster_labels_three == i]
    ith_cluster_silhouette_values.sort()
    size_cluster_i = ith_cluster_silhouette_values.shape[0]
    y_upper = y_lower + size_cluster_i
    color = cm.nipy_spectral(float(i) / 3)
    plt.fill_betweenx(np.arange(y_lower, y_upper),
                      0, ith_cluster_silhouette_values,
                      facecolor = color, edgecolor = color, alpha = 0.7)
    plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
    y_lower = y_upper + 10
plt.title("Silhouette plot for three agglomerative clusters")
plt.xlabel("Silhouette Coefficient")
plt.ylabel("Cluster Label")
# Include average silhouette score of all the values
plt.axvline(x = silhouette_avg_three, color = 'red', linestyle = '--')
plt.text(silhouette_avg_three * 1.1, plt.ylim()[1] * 0.05, f'Avg silhouette score:\n {silhouette_avg_three:.4f}', color = 'black', ha = 'left')
plt.yticks([])
plt.xticks([-0.2, 0, 0.2, 0.4, 0.6])
# Create 2D t-SNE plot with 3 agglomerative clusters
plt.subplot(1, 2, 2)
for label in np.unique(cluster_labels_three):
    plt.scatter(tsne_result[cluster_labels_three == label, 0], tsne_result[cluster_labels_three == label, 1], label = f'Cluster {label}')
plt.title('2D t-SNE plot for three agglomerative clusters')
plt.xlabel('t-SNE Component 1')
plt.ylabel('t-SNE Component 2')
plt.legend()
plt.tight_layout()
plt.show()
# Compute and print internal validation metrics for 6 agglomerative clusters
db_score_six = davies_bouldin_score(combined_pca_df, cluster_labels_six)
dunn_score_six = dunn_index(combined_pca_df, cluster_labels_six)
silhouette_avg_six = silhouette_score(combined_pca_df, cluster_labels_six)
ch_score_six = calinski_harabasz_score(combined_pca_df, cluster_labels_six)
print(f"Internal validation metrics for six agglomerative clusters:")
print(f"Davies-Bouldin Index : {db_score_six:.4f}")
print(f"Dunn Index : {dunn_score_six:.4f}")
print(f"Silhouette Score : {silhouette_avg_six:.4f}")
print(f"Calinski-Harabasz Score : {ch_score_six:.2f}")
# Create silhouette plot with 6 agglomerative clusters
sample_silhouette_values_six = silhouette_samples(combined_pca_df, cluster_labels_six)
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
y_lower = 10
for i in range(6):
    ith_cluster_silhouette_values = sample_silhouette_values_six[cluster_labels_six == i]
    ith_cluster_silhouette_values.sort()
    size_cluster_i = ith_cluster_silhouette_values.shape[0]
    y_upper = y_lower + size_cluster_i
    color = cm.nipy_spectral(float(i) / 6)
    plt.fill_betweenx(np.arange(y_lower, y_upper),
                      0, ith_cluster_silhouette_values,
                      facecolor = color, edgecolor = color, alpha = 0.7)
    plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
    y_lower = y_upper + 10
plt.title("Silhouette plot for six agglomerative clusters")
plt.xlabel("Silhouette Coefficient")
plt.ylabel("Cluster Label")
# Include average silhouette score of all the values
plt.axvline(x = silhouette_avg_six, color = 'red', linestyle = '--')
plt.text(silhouette_avg_six * 1.1, plt.ylim()[1] * 0.01, f'Avg silhouette score:\n {silhouette_avg_six:.4f}', color = 'black', ha = 'left')
plt.yticks([])
plt.xticks([-0.2, 0, 0.2, 0.4, 0.6])
# Create 2D t-SNE plot with 6 agglomerative clusters
plt.subplot(1, 2, 2)
for label in np.unique(cluster_labels_six):
    plt.scatter(tsne_result[cluster_labels_six == label, 0], tsne_result[cluster_labels_six == label, 1], label = f'Cluster {label}')
plt.title('2D t-SNE plot for six agglomerative clusters')
plt.xlabel('t-SNE Component 1')
plt.ylabel('t-SNE Component 2')
plt.legend()
plt.tight_layout()
plt.show()
# Compute and print internal validation metrics for 8 agglomerative clusters
db_score_eight = davies_bouldin_score(combined_pca_df, cluster_labels_eight)
dunn_score_eight = dunn_index(combined_pca_df, cluster_labels_eight)
silhouette_avg_eight = silhouette_score(combined_pca_df, cluster_labels_eight)
ch_score_eight = calinski_harabasz_score(combined_pca_df, cluster_labels_eight)
print(f"Internal validation metrics for eight agglomerative clusters:")
print(f"Davies-Bouldin Index : {db_score_eight:.4f}")
print(f"Dunn Index : {dunn_score_eight:.4f}")
print(f"Silhouette Score : {silhouette_avg_eight:.4f}")
print(f"Calinski-Harabasz Score : {ch_score_eight:.2f}")
# Create silhouette plot with 8 agglomerative clusters
sample_silhouette_values_eight = silhouette_samples(combined_pca_df, cluster_labels_eight)
plt.figure(figsize = (12, 6))
plt.subplot(1, 2, 1)
y_lower = 10
for i in range(8):
    ith_cluster_silhouette_values = sample_silhouette_values_eight[cluster_labels_eight == i]
    ith_cluster_silhouette_values.sort()
    size_cluster_i = ith_cluster_silhouette_values.shape[0]
    y_upper = y_lower + size_cluster_i
    color = cm.nipy_spectral(float(i) / 8)
    plt.fill_betweenx(np.arange(y_lower, y_upper),
                      0, ith_cluster_silhouette_values,
                      facecolor = color, edgecolor = color, alpha = 0.7)
    plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
    y_lower = y_upper + 10
plt.title("Silhouette plot for eight agglomerative clusters")
plt.xlabel("Silhouette Coefficient")
plt.ylabel("Cluster Label")
# Include average silhouette score of all the values
plt.axvline(x = silhouette_avg_eight, color = 'red', linestyle = '--')
plt.text(silhouette_avg_eight * 1.1, plt.ylim()[1] * 0.05, f'Avg silhouette score:\n {silhouette_avg_eight:.4f}', color = 'black', ha = 'left')
plt.yticks([])
plt.xticks([-0.2, 0, 0.2, 0.4, 0.6])
# Create 2D t-SNE plot with 8 agglomerative clusters
plt.subplot(1, 2, 2)
for label in np.unique(cluster_labels_eight):
    plt.scatter(tsne_result[cluster_labels_eight == label, 0], tsne_result[cluster_labels_eight == label, 1], label = f'Cluster {label}')
plt.title('2D t-SNE plot for eight agglomerative clusters')
plt.xlabel('t-SNE Component 1')
plt.ylabel('t-SNE Component 2')
plt.legend()
plt.tight_layout()
plt.show()
Recommended number of clusters based on Davies-Bouldin: 8
Recommended number of clusters based on Dunn Index: 3
Recommended number of clusters based on Silhouette Score: 6
Recommended number of clusters based on Calinski-Harabasz: 3
Internal validation metrics for three agglomerative clusters:
Davies-Bouldin Index : 2.6962
Dunn Index : 0.1394
Silhouette Score : 0.0895
Calinski-Harabasz Score : 1066.58
Internal validation metrics for six agglomerative clusters:
Davies-Bouldin Index : 2.4095
Dunn Index : 0.1277
Silhouette Score : 0.1068
Calinski-Harabasz Score : 981.43
Internal validation metrics for eight agglomerative clusters:
Davies-Bouldin Index : 2.2422
Dunn Index : 0.1277
Silhouette Score : 0.0981
Calinski-Harabasz Score : 887.02
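The fit-and-score steps above are repeated once each for three, six and eight clusters. As a hedged sketch, the repeated part could be wrapped in one helper and called in a loop. `evaluate_ward` is a hypothetical name, the demo data comes from `make_blobs` rather than the notebook's `combined_pca_df`, and the Dunn index is left out to keep the block short.

```python
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.metrics import (
    calinski_harabasz_score,
    davies_bouldin_score,
    silhouette_score,
)


def evaluate_ward(X, n_clusters):
    """Fit Ward-linkage agglomerative clustering and return internal validation metrics."""
    labels = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward").fit_predict(X)
    return {
        "n_clusters": n_clusters,
        "davies_bouldin": davies_bouldin_score(X, labels),        # lower is better
        "silhouette": silhouette_score(X, labels),                # higher is better
        "calinski_harabasz": calinski_harabasz_score(X, labels),  # higher is better
    }


# Synthetic demo data: three well-separated blobs
X_demo, _ = make_blobs(
    n_samples=300,
    centers=[[0, 0], [10, 0], [0, 10]],
    cluster_std=0.5,
    random_state=0,
)

results = [evaluate_ward(X_demo, k) for k in (3, 6, 8)]
best_by_silhouette = max(results, key=lambda r: r["silhouette"])["n_clusters"]

for row in results:
    print(row)
print("Best cluster count by silhouette:", best_by_silhouette)
```

On this cleanly separated toy data the metrics agree on three clusters. The notebook's real data gives silhouette scores near 0.1 and disagreeing recommendations, which suggests the clusters there overlap considerably.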
Q1c - Discriminator modelling¶
Q1c Step 1 - Random Forest¶
# Q1c Step 1
# Prepare the data
cluster_labels = agg_clustering_six.fit_predict(combined_pca_df)
discriminator_df = tfidf_revised_df.assign(cluster = cluster_labels)
discriminator_df['cluster'] = 'cluster_' + discriminator_df['cluster'].astype(str)
display(discriminator_df.head())
X = discriminator_df.drop('cluster', axis = 1)
Y = discriminator_df['cluster']
# Initialise and fit the random forest model
rf_model = RandomForestClassifier(random_state = random_seed)
rf_model.fit(X, Y)
| abbrev | ability | abnormal | abnormality | abnormally | above | absence | absolute | accordingly | accurate | accurately | acid | acidosis | action | activity | additionally | address | addressed | addressing | adequate | adjustment | adopt | adopting | adverse | advice | advisable | advise | advised | advising | affect | affecting | age | albumin | alcohol | alkaline | alkalosis | allergic | allergy | alt | ambig | ... | term | thank | thorough | thrombocytopenia | thyroid | thyroxine | time | tobacco | track | tract | treatment | triglyceride | troponin | tsh | type | ul | undergo | underlying | understand | unlikely | upper | ups | uric | urinalysis | urinary | urine | urobilinogen | use | various | vera | vitamin | vldl | warranted | weakness | weight | white | work | worth | young | cluster | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.000000 | 0.248112 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.133393 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.083180 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.148389 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | cluster_1 |
| 1 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.268809 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.166351 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | cluster_0 |
| 2 | 0.0 | 0.322954 | 0.085524 | 0.0 | 0.283533 | 0.139798 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.191577 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.16223 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.172034 | 0.287586 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.177162 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | cluster_3 |
| 3 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.186827 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.35721 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | cluster_2 |
| 4 | 0.0 | 0.000000 | 0.071979 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.15042 | 0.0 | 0.0 | 0.0 | 0.0 | 0.00000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | cluster_0 |
5 rows × 501 columns
RandomForestClassifier(random_state=41387)
Q1c Step 2 - Variable Importance (Top 10 Keywords)¶
# Q1c Step 2
# Store feature importance in a DataFrame
feature_importance_df = pd.DataFrame({'feature': X.columns, 'importance': rf_model.feature_importances_})
# Sort once and keep the 10 most important features
top_10_importance_df = feature_importance_df.sort_values(by = 'importance', ascending = False).head(10)
# Display the top 10 most important features
print("Top 10 most important features:")
display(top_10_importance_df)
# Plot the top 10 most important features
plt.figure(figsize = (8, 5))
plt.barh(top_10_importance_df['feature'], top_10_importance_df['importance'])
plt.title('Top 10 Most Important Features')
plt.xlabel('Feature Importance')
plt.ylabel('Feature')
plt.gca().invert_yaxis()
plt.grid(axis = 'x')
plt.show()
Top 10 most important features:
| | feature | importance |
|---|---|---|
| 69 | bilirubin | 0.060944 |
| 32 | albumin | 0.051588 |
| 206 | globulin | 0.033499 |
| 287 | liver | 0.028766 |
| 478 | underlying | 0.022726 |
| 89 | chloride | 0.022059 |
| 472 | triglyceride | 0.018504 |
| 172 | elevated | 0.015947 |
| 330 | normal | 0.013106 |
| 79 | cardiovascular | 0.012760 |
Q1c Step 3 - Examination of Top 10 Keywords and Implications on Clusters¶
# Q1c Step 3
# Calculate the proportion of rows in each cluster that have a non-zero value
# for each of the top 10 features
top_10_features = feature_importance_df.sort_values(by = 'importance', ascending = False).head(10)['feature'].tolist()
cluster_proportions_df = pd.DataFrame(index = top_10_features)
for cluster in discriminator_df['cluster'].unique():
    cluster_data = discriminator_df[discriminator_df['cluster'] == cluster]
    proportions = []
    for feature in top_10_features:
        proportions.append(cluster_data[feature].astype(bool).sum() / len(cluster_data))
    cluster_proportions_df[cluster] = proportions
display(cluster_proportions_df.reindex(columns = sorted(cluster_proportions_df.columns)).style.background_gradient())
# Analyse how good the top 10 features are at predicting each cluster
# Display the feature importance from most to least important
# Display the partial dependence plots
X_top_10_features = X[top_10_features]
# Check for NaN values in top 10 features
print("Total NaN values in top 10 features:", X_top_10_features.isna().sum().sum())
print()
# For cluster 0
# Initialise and fit the random forest model
rf_model_0 = RandomForestClassifier(random_state = random_seed)
rf_model_0.fit(X_top_10_features, discriminator_df['cluster'] == 'cluster_0')
# Display feature importance
feature_importance_0 = pd.Series(rf_model_0.feature_importances_, index = X_top_10_features.columns).sort_values(ascending = False)
print("Feature importance for cluster_0:")
display(feature_importance_0)
# Calculate and display partial dependence plots
percentiles = (0.01, 0.99)
plt.figure(figsize = (8, 12))
PartialDependenceDisplay.from_estimator(rf_model_0, X_top_10_features, features = feature_importance_0.index.tolist(), percentiles = percentiles, ax = plt.gca())
plt.tight_layout()
plt.show()
# For cluster 1
# Initialise and fit the random forest model
rf_model_1 = RandomForestClassifier(random_state = random_seed)
rf_model_1.fit(X_top_10_features, discriminator_df['cluster'] == 'cluster_1')
# Display feature importance
feature_importance_1 = pd.Series(rf_model_1.feature_importances_, index = X_top_10_features.columns).sort_values(ascending = False)
print("Feature importance for cluster_1:")
display(feature_importance_1)
# Calculate and display partial dependence plots
plt.figure(figsize = (8, 12))
PartialDependenceDisplay.from_estimator(rf_model_1, X_top_10_features, features = feature_importance_1.index.tolist(), percentiles = percentiles, ax = plt.gca())
plt.tight_layout()
plt.show()
# For cluster 2
# Initialise and fit the random forest model
rf_model_2 = RandomForestClassifier(random_state = random_seed)
rf_model_2.fit(X_top_10_features, discriminator_df['cluster'] == 'cluster_2')
# Display feature importance
feature_importance_2 = pd.Series(rf_model_2.feature_importances_, index = X_top_10_features.columns).sort_values(ascending = False)
print("Feature importance for cluster_2:")
display(feature_importance_2)
# Calculate and display partial dependence plots
plt.figure(figsize = (8, 12))
PartialDependenceDisplay.from_estimator(rf_model_2, X_top_10_features, features = feature_importance_2.index.tolist(), percentiles = percentiles, ax = plt.gca())
plt.tight_layout()
plt.show()
# For cluster 3
# Initialise and fit the random forest model
rf_model_3 = RandomForestClassifier(random_state = random_seed)
rf_model_3.fit(X_top_10_features, discriminator_df['cluster'] == 'cluster_3')
# Display feature importance
feature_importance_3 = pd.Series(rf_model_3.feature_importances_, index = X_top_10_features.columns).sort_values(ascending = False)
print("Feature importance for cluster_3:")
display(feature_importance_3)
# Calculate and display partial dependence plots
plt.figure(figsize = (8, 12))
PartialDependenceDisplay.from_estimator(rf_model_3, X_top_10_features, features = feature_importance_3.index.tolist(), percentiles = percentiles, ax = plt.gca())
plt.tight_layout()
plt.show()
# For cluster 4
# Initialise and fit the random forest model
rf_model_4 = RandomForestClassifier(random_state = random_seed)
rf_model_4.fit(X_top_10_features, discriminator_df['cluster'] == 'cluster_4')
# Display feature importance
feature_importance_4 = pd.Series(rf_model_4.feature_importances_, index = X_top_10_features.columns).sort_values(ascending = False)
print("Feature importance for cluster_4:")
display(feature_importance_4)
# Calculate and display partial dependence plots
plt.figure(figsize = (8, 12))
PartialDependenceDisplay.from_estimator(rf_model_4, X_top_10_features, features = feature_importance_4.index.tolist(), percentiles = percentiles, ax = plt.gca())
plt.tight_layout()
plt.show()
# For cluster 5
# Initialise and fit the random forest model
rf_model_5 = RandomForestClassifier(random_state = random_seed)
rf_model_5.fit(X_top_10_features, discriminator_df['cluster'] == 'cluster_5')
# Display feature importance
feature_importance_5 = pd.Series(rf_model_5.feature_importances_, index = X_top_10_features.columns).sort_values(ascending = False)
print("Feature importance for cluster_5:")
display(feature_importance_5)
# Calculate and display partial dependence plots
plt.figure(figsize = (8, 12))
PartialDependenceDisplay.from_estimator(rf_model_5, X_top_10_features, features = feature_importance_5.index.tolist(), percentiles = percentiles, ax = plt.gca())
plt.tight_layout()
plt.show()
| | cluster_0 | cluster_1 | cluster_2 | cluster_3 | cluster_4 | cluster_5 |
|---|---|---|---|---|---|---|
| bilirubin | 0.142617 | 0.056916 | 0.360243 | 0.921128 | 0.074819 | 0.040864 |
| albumin | 0.072569 | 0.028617 | 0.849277 | 0.226577 | 0.056315 | 0.037779 |
| globulin | 0.120762 | 0.191097 | 0.853010 | 0.331740 | 0.239743 | 0.079414 |
| liver | 0.171196 | 0.207949 | 0.614559 | 0.761472 | 0.176991 | 0.117965 |
| underlying | 0.399832 | 0.865501 | 0.472235 | 0.503824 | 0.793242 | 0.360062 |
| chloride | 0.081255 | 0.059141 | 0.172189 | 0.159178 | 0.643604 | 0.025443 |
| triglyceride | 0.156907 | 0.173609 | 0.189454 | 0.162046 | 0.177796 | 0.788743 |
| elevated | 0.350799 | 0.716375 | 0.400373 | 0.526769 | 0.629123 | 0.690054 |
| normal | 0.585598 | 0.554213 | 0.672888 | 0.704111 | 0.457763 | 0.477255 |
| cardiovascular | 0.154105 | 0.163752 | 0.126458 | 0.118069 | 0.136766 | 0.765613 |
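The nested loop that builds this proportions table can also be written as a single grouped aggregation, because the share of rows with a non-zero TF-IDF weight is simply the mean of a boolean mask within each cluster. The sketch below is one possible equivalent, shown on a small made-up frame (`toy_df` and its values are illustrative assumptions, not the assignment data).

```python
import pandas as pd

# Toy stand-in for discriminator_df: TF-IDF weights plus a cluster label.
# All values are made up for illustration only.
toy_df = pd.DataFrame({
    'bilirubin': [0.0, 0.3, 0.2, 0.2],
    'albumin':   [0.1, 0.0, 0.0, 0.0],
    'cluster':   ['cluster_0', 'cluster_0', 'cluster_1', 'cluster_1'],
})
features = ['bilirubin', 'albumin']

# ne(0) marks rows where the keyword appears; the group mean of that
# boolean mask is the proportion of rows in the cluster containing it
proportions = (
    toy_df[features].ne(0)
    .groupby(toy_df['cluster'])
    .mean()
    .T  # keywords as rows, clusters as columns (same layout as above)
)
print(proportions)
```

On the real data, `features` would be `top_10_features` and `toy_df` would be `discriminator_df`.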
Total NaN values in top 10 features: 0

Feature importance for cluster_0:
| | 0 |
|---|---|
| underlying | 0.219920 |
| normal | 0.171817 |
| elevated | 0.162336 |
| bilirubin | 0.074986 |
| liver | 0.074970 |
| triglyceride | 0.073895 |
| globulin | 0.073139 |
| cardiovascular | 0.067177 |
| chloride | 0.041320 |
| albumin | 0.040440 |
Feature importance for cluster_1:
| | 0 |
|---|---|
| underlying | 0.297355 |
| elevated | 0.161232 |
| normal | 0.138125 |
| bilirubin | 0.072422 |
| liver | 0.067188 |
| chloride | 0.063655 |
| triglyceride | 0.055452 |
| albumin | 0.053771 |
| cardiovascular | 0.048447 |
| globulin | 0.042354 |
Feature importance for cluster_2:
| | 0 |
|---|---|
| albumin | 0.376588 |
| globulin | 0.186853 |
| normal | 0.091885 |
| liver | 0.082546 |
| bilirubin | 0.082535 |
| underlying | 0.058016 |
| elevated | 0.052546 |
| chloride | 0.025228 |
| triglyceride | 0.024292 |
| cardiovascular | 0.019510 |
Feature importance for cluster_3:
| | 0 |
|---|---|
| bilirubin | 0.434692 |
| liver | 0.168264 |
| normal | 0.095851 |
| underlying | 0.061391 |
| elevated | 0.059566 |
| albumin | 0.059180 |
| globulin | 0.052149 |
| chloride | 0.026986 |
| triglyceride | 0.025357 |
| cardiovascular | 0.016562 |
Feature importance for cluster_4:
| | 0 |
|---|---|
| chloride | 0.330133 |
| underlying | 0.191739 |
| elevated | 0.124110 |
| normal | 0.114135 |
| bilirubin | 0.065134 |
| liver | 0.044613 |
| globulin | 0.039479 |
| triglyceride | 0.038301 |
| cardiovascular | 0.029287 |
| albumin | 0.023069 |
Feature importance for cluster_5:
| | 0 |
|---|---|
| triglyceride | 0.283398 |
| cardiovascular | 0.274849 |
| elevated | 0.109829 |
| normal | 0.101235 |
| underlying | 0.088727 |
| liver | 0.040053 |
| bilirubin | 0.038120 |
| globulin | 0.030108 |
| chloride | 0.022444 |
| albumin | 0.011236 |
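The six per-cluster blocks in Q1c Step 3 repeat one pattern: fit a random forest on a one-vs-rest target (`cluster == 'cluster_k'`) and read off the feature importances. A minimal sketch of the same pattern as a loop is shown below. It runs on made-up data (`X_toy`, `labels`, and the reduced `n_estimators` are assumptions for illustration) and leaves out the partial dependence plots.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Toy stand-in for X_top_10_features and the cluster labels (made-up data).
rng = np.random.default_rng(0)
X_toy = pd.DataFrame(rng.random((60, 3)), columns=['bilirubin', 'albumin', 'chloride'])
# The toy label depends only on 'bilirubin', so it should dominate.
labels = pd.Series(np.where(X_toy['bilirubin'] > 0.5, 'cluster_0', 'cluster_1'))

importance_by_cluster = {}
for cluster in sorted(labels.unique()):
    # One-vs-rest target: True for rows in this cluster, False otherwise
    model = RandomForestClassifier(n_estimators=50, random_state=0)
    model.fit(X_toy, labels == cluster)
    importance_by_cluster[cluster] = pd.Series(
        model.feature_importances_, index=X_toy.columns
    ).sort_values(ascending=False)

# One column of importances per cluster, aligned on feature name
importance_table = pd.DataFrame(importance_by_cluster)
print(importance_table.round(3))
```

Collecting the results in a dictionary keeps one importance ranking per cluster available for side-by-side comparison, at the cost of no longer having a separately named model object per cluster.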
Q1d - Manual validation¶
Q1d Step 1 - Word Clouds¶
# Q1d Step 1
# Combine pathology df with cluster labels
manual_val_df = pathology_df.assign(cluster = cluster_labels)
# Generate word clouds for each cluster
for cluster in sorted(manual_val_df['cluster'].unique()):
    text = ' '.join(manual_val_df[manual_val_df['cluster'] == cluster]['tfidf_commentary'].astype(str))
    wordcloud = WordCloud(width = 750, height = 400, background_color = 'white').generate(text)
    plt.figure(figsize = (10, 5))
    plt.imshow(wordcloud, interpolation = 'bilinear')
    plt.axis('off')
    plt.title(f'Word Cloud for Cluster {cluster}')
    plt.show()
    if cluster < 5:
        print()
Q1d Step 2 - Word Counts¶
# Q1d Step 2
# Generate most common words by cluster
top_words_per_cluster = {}
for cluster, texts in manual_val_df.groupby('cluster')['tfidf_commentary']:
    all_words = ' '.join(texts).split()
    word_counts = Counter(all_words)
    top_words = [word for word, _ in word_counts.most_common(10)]
    top_words_per_cluster[cluster] = top_words
print("Top 10 words by cluster:\n")
for cluster, words in sorted(top_words_per_cluster.items()):
    print(f"Cluster {cluster}: {', '.join(words)}\n")
# Calculate the average, minimum, and maximum word counts per report for each
# cluster
cluster_statistics = []
for i in sorted(manual_val_df['cluster'].unique()):
    cluster_df = manual_val_df[manual_val_df['cluster'] == i]
    word_counts = cluster_df['tfidf_commentary'].str.split().str.len()
    mean_words = round(word_counts.mean(), 2) if not word_counts.empty else 0
    min_words = round(word_counts.min(), 2) if not word_counts.empty else 0
    max_words = round(word_counts.max(), 2) if not word_counts.empty else 0
    cluster_statistics.append({'Cluster': i, 'Mean': mean_words, 'Minimum': min_words, 'Maximum': max_words})
print("-" * 110)
print()
print("Average, minimum, and maximum words of the report commentaries in each cluster:")
display(pd.DataFrame(cluster_statistics))
Top 10 words by cluster:

Cluster 0: smoking, normal, abnormal, elevated, regular, potential, smoker, underlying, no, abnormality

Cluster 1: smoking, elevated, underlying, abnormal, normal, protein, potential, abnormality, low, infection

Cluster 2: smoking, liver, globulin, normal, albumin, abnormal, elevated, potential, underlying, kidney

Cluster 3: liver, smoking, bilirubin, normal, elevated, abnormal, underlying, potential, smoker, globulin

Cluster 4: smoking, abnormal, elevated, underlying, chloride, abnormality, potassium, low, normal, potential

Cluster 5: triglyceride, smoking, cardiovascular, risk, elevated, abnormal, disease, regular, diet, cholesterol

--------------------------------------------------------------------------------------------------------------

Average, minimum, and maximum words of the report commentaries in each cluster:
| | Cluster | Mean | Minimum | Maximum |
|---|---|---|---|---|
| 0 | 0 | 33.47 | 6 | 112 |
| 1 | 1 | 40.54 | 9 | 102 |
| 2 | 2 | 39.22 | 9 | 121 |
| 3 | 3 | 36.69 | 10 | 103 |
| 4 | 4 | 39.57 | 10 | 141 |
| 5 | 5 | 39.90 | 11 | 95 |
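The per-cluster word count statistics can also be produced with one grouped aggregation instead of an explicit loop. The sketch below shows this on a tiny made-up frame (`toy_df` is an assumption for illustration). Note that it leaves `Cluster` as the index instead of a column.

```python
import pandas as pd

# Toy stand-in for manual_val_df (made-up commentaries).
toy_df = pd.DataFrame({
    'cluster': [0, 0, 1],
    'tfidf_commentary': ['smoking normal abnormal', 'elevated', 'liver bilirubin'],
})

# Number of whitespace-separated words in each commentary
word_counts = toy_df['tfidf_commentary'].str.split().str.len()

# Named aggregation: one output column per statistic, one row per cluster
stats = (
    word_counts.groupby(toy_df['cluster'])
    .agg(Mean='mean', Minimum='min', Maximum='max')
    .round(2)
)
print(stats)
```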
Q1d Step 3 - Sentiments¶
# Q1d Step 3
# Initialise the sentiment analysis pipeline
sentiment_classifier = pipeline("sentiment-analysis")
# Function to get sentiment
def get_sentiment(text):
    try:
        result = sentiment_classifier(text)[0]
        return result['label']
    except Exception:
        return "Error"
# Apply the function to the 'tfidf_commentary' column
manual_val_df['sentiment'] = manual_val_df['tfidf_commentary'].apply(get_sentiment)
# Plot sentiment proportions by cluster
sentiment_counts = manual_val_df.groupby(['cluster', 'sentiment']).size().unstack(fill_value = 0)
sentiment_proportions = sentiment_counts.div(sentiment_counts.sum(axis = 1), axis = 0)
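# DataFrame.plot creates its own figure, so the size is passed to it directly
# (a separate plt.figure call would only open an extra, empty figure)
sentiment_proportions.plot(kind = 'bar', stacked = True, figsize = (10, 8))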
plt.title('Sentiment Proportions per Cluster')
plt.xlabel('Cluster')
plt.ylabel('Proportion')
plt.xticks(rotation = 0)
plt.legend(title = 'Sentiment', loc = 'upper right')
plt.show()
No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english). Using a pipeline without specifying a model name and revision in production is not recommended.
Device set to use cuda:0 You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
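The sentiment proportions per cluster are row-normalised counts, which `pd.crosstab` can compute in one call. The sketch below is a hedged alternative to the `groupby` / `unstack` / `div` chain, shown on made-up labels (`toy_df` is illustrative, and the strings `POSITIVE` and `NEGATIVE` are assumed to match the labels returned by the default model).

```python
import pandas as pd

# Toy stand-in for manual_val_df after the sentiment column has been added.
toy_df = pd.DataFrame({
    'cluster':   [0, 0, 0, 1, 1],
    'sentiment': ['NEGATIVE', 'NEGATIVE', 'POSITIVE', 'NEGATIVE', 'POSITIVE'],
})

# normalize='index' divides each row by its total, so every cluster sums to 1
sentiment_proportions = pd.crosstab(
    toy_df['cluster'], toy_df['sentiment'], normalize='index'
)
print(sentiment_proportions)
```

The resulting frame has the same shape as `sentiment_proportions` in the cell above, so it can be passed to the same stacked bar plot.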
Q1d Step 4 - Centroids¶
# Q1d Step 4
# Function to find row in each cluster that has the embedding vector closest to
# the average embedding vector for that cluster
def closest_to_centroid(df):
    # Use the DataFrame passed in rather than relying on a global variable
    embedding_cols = [col for col in df.columns if col.startswith("embedding_")]
    closest_rows = {}
    for cluster in sorted(df['cluster'].unique()):
        cluster_data = df[df['cluster'] == cluster]
        centroid = cluster_data[embedding_cols].mean()
        min_distance = float('inf')
        closest_row_index = -1
        for index, row in cluster_data.iterrows():
            embedding_vector = row[embedding_cols].values
            distance = np.linalg.norm(embedding_vector - centroid.values)
            if distance < min_distance:
                min_distance = distance
                closest_row_index = index
        closest_rows[cluster] = closest_row_index
    return closest_rows
# Show the written_report from the row in each cluster that has the embedding
# vector closest to the average embedding vector for that cluster
for cluster, index in closest_to_centroid(manual_val_df).items():
    print(f"Cluster {cluster}:\n")
    print(manual_val_df.loc[index, 'written_report'])
    if cluster < 5:
        print()
        print("-" * 520)
        print()
Cluster 0: **Pathology Report:** The blood test results for Linda Ramos, a 21-year-old female patient, have been reviewed. The following abnormal results were noted: 1. **Bilirubin:** The result is 0.3 mg/dL, which is below the normal low range (97-108 mg/dL). This may indicate a potential issue with liver function. 2. **Triglyceride:** The result is 123.0 mg/dL, which is above the normal high range (1.8-7.8 mg/dL). Elevated triglyceride levels are associated with an increased risk of cardiovascular disease. 3. **C-Reactive Protein:** The result is 202.0 mg/dL, exceeding the normal range of 10.0-20.0 mg/dL. Elevated levels may indicate inflammation in the body. 4. **Protein Total:** The result is not provided; further evaluation is needed to assess protein levels. 5. **Ambig Abbrev CMP14 Default:** The result is 7.6 g/dL, above the normal high range. This may indicate abnormal protein levels in the blood. **Comments:** Considering Linda Ramos' age and the absence of smoking status information, it is important to note that the abnormal results, particularly the elevated triglyceride and C-reactive protein levels, could pose a risk to her cardiovascular health. Further evaluation and lifestyle modifications, such as dietary changes and increased physical activity, may be recommended to reduce the risk of future health complications. Follow-up testing and consultation with a healthcare provider are advised to address the abnormal results and develop a personalized treatment plan. 
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Cluster 1: Pathology Report: Patient: Wanda Brenton Date of Birth: December 5, 2003 Smoking Status: NA Abnormal Results: 1. Hematocrit: 41.4% (Reference Range: 36.0-50.0%), Above Normal High 2. Hemoglobin: 55.4 g/dL (Reference Range: 12.5-17.0) 3. Platelet Count: 298.0 x10E3/uL (Reference Range: 1.5-4.5) 4. Potassium, Serum: 3.3 mmol/L (Reference Range: 3.5-5.2), Below Normal Low 5. Protein Total: 6.1 g/dL (Reference Range: 6.0-8.6), Above Normal High Comments: The blood test results for Wanda Brenton show several abnormalities that require attention. The elevated hematocrit and hemoglobin levels suggest a possible issue with dehydration or other underlying conditions that need to be investigated further. The significantly high platelet count also requires additional evaluation to rule out any potential hematologic disorder. The low potassium level in the serum can lead to various complications, including muscle weakness and cardiac arrhythmias. It is important to address this electrolyte imbalance promptly. Additionally, the elevated total protein level may indicate inflammation, infections, or other systemic disorders that need to be explored. Recommendations: 1. Further investigation into the elevated hematocrit, hemoglobin, and platelet count to determine the underlying cause. 2. Address the low potassium level promptly to prevent potential complications. 3. 
Evaluate the elevated total protein level to identify any underlying inflammatory or systemic conditions. Given the patient's young age and lack of smoking history, the abnormalities in the blood test results are concerning and warrant prompt attention and thorough evaluation by a healthcare provider. ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Cluster 2: **Pathology Report** **Patient Information:** - Name: Bertie Lee - Date of Birth: September 3, 1933 - Smoking Status: 0 cigarettes per day (previous smoker) --- **Laboratory Abnormal Results:** 1. **Albumin / Globulin Ratio:** - Result: 1.7 (Reference Range: 6.0-8.5) - Abnormal: Yes - Implications: The albumin/globulin ratio is below the normal range, indicating possible liver or kidney diseases. 2. **Hematocrit:** - Result: 39.4% (Reference Range: 41-50) - Abnormal: No 3. **Protein Total:** - Result: 6.8 g/dL (Reference Range: 6.4-8.3) - Abnormal: No 4. **Potassium, Serum:** - Result: 4.8 mmol/L (Reference Range: 3.5-5.2) - Abnormal: No --- **Comments:** - The albumin/globulin ratio is below the normal range, which may indicate liver or kidney issues. Further evaluation and follow-up are recommended to determine the underlying cause. It's important to consider Bertie's smoking history as smoking can contribute to liver and kidney diseases. An evaluation by a specialist may be necessary to assess the need for additional testing or treatment options. This report provides an overview of Bertie Lee's blood test results. 
Further consultation with a healthcare provider is advised for a comprehensive evaluation and management plan. ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Cluster 3: **Pathology Report:** **Patient Information:** - Name: Marisa King - Date of Birth: November 28, 1986 - Smoking Status: Non-smoker **Laboratory Results:** 1. **Albumin / Globulin Ratio:** Not available 2. **Bilirubin:** 0.3 mg/dL (Reference Range: 0.2-1.2 mg/dL) 3. **Chloride, Serum:** 104.0 mmol/L (Reference Range: 98-107 mmol/L) 4. **Globulin:** 2.5 g/dL (Reference Range: 2.0-4.0 g/dL) 5. **Neutrophils:** Not available 6. **RDW:** Not available 7. **Platelet Count:** 2.6 x10E3/uL (Reference Range: 1.5-4.5 x10E3/uL) **Comments:** - The results of the blood tests for Marisa King show a slightly elevated bilirubin level at 0.3 mg/dL, which is within the normal range but worth monitoring in the future. - The chloride level is slightly elevated at 104.0 mmol/L, which could be influenced by various factors including diet or hydration status. Further evaluation may be needed if this persists. - The platelet count is within the normal range at 2.6 x10E3/uL. Platelet counts can be affected by smoking, but since Ms. King is a non-smoker, this result is not likely related to smoking status. - The albumin / globulin ratio and neutrophil count are not available for review at this time. **Recommendations:** - Follow-up testing may be warranted to monitor the bilirubin and chloride levels for any changes over time. - Considering Ms. 
King's non-smoking status, the platelet count result is not concerning at this time. - Additional testing may be needed to obtain the missing data points for a comprehensive evaluation of the blood test results. ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Cluster 4: **Pathology Report** **Patient Information:** - Name: Mallory Keeley - Date of Birth: January 11, 1962 - Smoking Status: Not available --- **Abnormal Results:** 1. **Chloride, Serum:** 91.0 mmol/L (Reference Range: 3.5-5.2 mmol/L) - Below Normal Low 2. **Hemoglobin:** 14.8 g/dL (Reference Range: 12.5-17.0 g/dL) - Above Normal High 3. **Protein Total:** 6.4 g/dL (Reference Range: 0.0-1.2 g/dL) - Above Normal High 4. **Nitrite:** Above Normal High --- **Comments:** The blood test results for Mallory Keeley show several abnormalities that may need further investigation. The low serum chloride levels, along with the high hemoglobin and total protein levels, could indicate underlying health issues. The elevated nitrite levels may also suggest an infection or inflammation. Since the smoking status of the patient is not available, it is important to note that smoking can impact certain blood test results, such as hemoglobin levels. If Mallory Keeley is a current smoker, this information should be considered when interpreting the results. It is recommended that the patient consult with a healthcare provider for a comprehensive evaluation and appropriate management based on these findings. 
--- *This report is based on the information provided and should be interpreted in conjunction with clinical findings and patient history.* ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Cluster 5: **Pathology Report:** **Patient Information:** - **Name:** Veronica Parker - **Date of Birth:** September 29, 1974 - **Smoking Status:** Not available --- **Lab Results:** 1. **Thyroxine:** 0.4 K/uL (Normal) 2. **Absolute Lymph:** Not available 3. **Urinalysis Reflex:** Not available 4. **Triglyceride:** 109.0 mg/dL - **Reference Range:** 1.8-7.8 mg/dL - **Abnormal Flags:** Above Normal High --- **Comments:** - The triglyceride level for Veronica Parker is significantly elevated at 109.0 mg/dL, well above the normal reference range of 1.8-7.8 mg/dL. This high level may indicate an increased risk of cardiovascular disease, particularly if combined with other risk factors. Considering the patient's date of birth, it is essential to address this abnormal value promptly to mitigate any potential health risks. **Recommendations:** 1. Given the elevated triglyceride level, further evaluation and management are recommended to lower the level and reduce the risk of cardiovascular complications. 2. Lifestyle modifications, such as a healthy diet, regular exercise, and smoking cessation (if applicable), should be advised to improve overall cardiovascular health. --- It is important to follow up with Veronica Parker for additional testing and potentially initiate treatment to address the elevated triglyceride levels. 
Regular monitoring and lifestyle changes can help manage the risk factors associated with high triglyceride levels.
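The search for the report closest to each cluster centroid can be vectorised with NumPy, which avoids iterating over rows one at a time. The sketch below is one way to do this under the same assumptions as the function above (embedding columns prefixed with `embedding_`, Euclidean distance); `toy_df` and its two-dimensional embeddings are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-in for manual_val_df with two-dimensional embeddings (made up).
toy_df = pd.DataFrame({
    'cluster':     [0, 0, 0, 1, 1, 1],
    'embedding_0': [0.0, 1.0, 5.0, 10.0, 12.0, 11.0],
    'embedding_1': [0.0, 1.0, 1.0, 10.0, 10.0, 10.5],
})
embedding_cols = [c for c in toy_df.columns if c.startswith('embedding_')]

closest_rows = {}
for cluster, group in toy_df.groupby('cluster'):
    vectors = group[embedding_cols].to_numpy(dtype=float)
    centroid = vectors.mean(axis=0)
    # Euclidean distance of every row in the cluster to the centroid
    distances = np.linalg.norm(vectors - centroid, axis=1)
    # Map the position of the smallest distance back to the row label
    closest_rows[cluster] = group.index[distances.argmin()]
print(closest_rows)
```

If two rows are equally close, `argmin` returns the first one, which matches the strict `<` comparison in the loop-based version.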
Q1d Step 5 - Other Feature (Gender)¶
# Q1d Step 5
# Join manual_val_df with labresult_df on LabResultGuid
manual_val_other_feats_df = pd.merge(manual_val_df, labresult_df, on = 'LabResultGuid', how = 'left')
# Join the result with patient_df on PatientGuid to retrieve gender
manual_val_other_feats_df = pd.merge(manual_val_other_feats_df, patient_df[['PatientGuid', 'Gender']].drop_duplicates(), on = 'PatientGuid', how = 'left')
# Check that number of rows in manual_val_df and manual_val_other_feats_df are equal
print(f"Number of rows in manual_val_df: {len(manual_val_df)}")
print(f"Number of rows in manual_val_other_feats_df: {len(manual_val_other_feats_df)}")
# Check for missing values in 'Gender' column
print("Missing values in manual_val_other_feats_df:")
print(manual_val_other_feats_df[['Gender']].isnull().sum())
# Calculate gender proportions by cluster
gender_cluster_counts = manual_val_other_feats_df.groupby(['cluster', 'Gender']).size().unstack(fill_value = 0)
gender_cluster_proportions = gender_cluster_counts.div(gender_cluster_counts.sum(axis = 1), axis = 0)
# Plot the stacked bar chart of gender proportions by cluster
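# DataFrame.plot creates its own figure, so the size is passed to it directly
# (a separate plt.figure call would only open an extra, empty figure)
gender_cluster_proportions.plot(kind = 'bar', stacked = True, figsize = (10, 6))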
plt.title('Distribution of Gender by Cluster')
plt.xlabel('Cluster')
plt.ylabel('Proportion')
plt.xticks(rotation = 0)
plt.legend(title = 'Gender', loc = 'upper right')
plt.show()
Number of rows in manual_val_df: 13489
Number of rows in manual_val_other_feats_df: 13489
Missing values in manual_val_other_feats_df:
Gender    0
dtype: int64
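The row-count comparison in this step guards against a left join silently duplicating reports. As a complementary check, `pd.merge` can enforce the expected key relationship itself through its `validate` argument, and `indicator=True` flags rows that found no match. The sketch below uses made-up frames (`reports` and `patients` are hypothetical stand-ins, and it skips the intermediate join through `labresult_df`).

```python
import pandas as pd

# Hypothetical stand-ins: several reports may belong to the same patient.
reports = pd.DataFrame({
    'LabResultGuid': ['r1', 'r2', 'r3'],
    'PatientGuid':   ['p1', 'p1', 'p2'],
})
patients = pd.DataFrame({
    'PatientGuid': ['p1', 'p2'],
    'Gender':      ['F', 'M'],
})

# validate='many_to_one' raises a MergeError if a PatientGuid appears more
# than once on the right-hand side, which is what would duplicate rows.
merged = pd.merge(
    reports, patients,
    on='PatientGuid', how='left',
    validate='many_to_one', indicator=True,
)

# Rows marked 'left_only' are reports with no matching patient record
unmatched = int((merged['_merge'] == 'left_only').sum())
print(len(merged), unmatched)
```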
Q1d Step 6 - Other Feature (Age)¶
# Q1d Step 6
# Join the result with patient_df on PatientGuid to retrieve DoB
manual_val_other_feats_df = pd.merge(manual_val_other_feats_df, patient_df[['PatientGuid', 'DateOfBirth']].drop_duplicates(), on = 'PatientGuid', how = 'left')
# Calculate age from DoB
manual_val_other_feats_df['age'] = (pd.to_datetime("now") - pd.to_datetime(manual_val_other_feats_df["DateOfBirth"])).dt.days // 365
# Check that number of rows in manual_val_df and manual_val_other_feats_df are
# equal
print(f"Number of rows in manual_val_df: {len(manual_val_df)}")
print(f"Number of rows in manual_val_other_feats_df: {len(manual_val_other_feats_df)}")
# Check for missing values in 'age' column
print("Missing values in manual_val_other_feats_df:")
print(manual_val_other_feats_df[['age']].isnull().sum())
# Plot a histogram of age distributions for each cluster
for cluster in sorted(manual_val_other_feats_df['cluster'].unique()):
    cluster_age_data = manual_val_other_feats_df[manual_val_other_feats_df['cluster'] == cluster].dropna(subset = ['age'])
    plt.figure(figsize = (8, 5))
    plt.hist(cluster_age_data['age'], bins = 10, edgecolor = 'black')
    plt.title(f'Age Distribution for Cluster {cluster}')
    plt.xlabel('Age')
    plt.ylabel('Frequency')
    plt.grid(axis = 'y')
    plt.show()
    if cluster < 5:
        print()
Number of rows in manual_val_df: 13489
Number of rows in manual_val_other_feats_df: 13489
Missing values in manual_val_other_feats_df:
age 0
dtype: int64
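Integer-dividing a day count by 365 ignores leap days, so a computed age can tick over a few days before the actual birthday. A minimal sketch of an exact calculation follows, assuming a fixed reference date so that the result is reproducible. The helper `age_in_years` and its `as_of` parameter are hypothetical names, and the reference date is arbitrary.

```python
import pandas as pd

def age_in_years(dob: pd.Series, as_of: pd.Timestamp) -> pd.Series:
    """Completed years between each date of birth and as_of."""
    dob = pd.to_datetime(dob)
    # True where the birthday has already occurred in the as_of year
    had_birthday = (dob.dt.month < as_of.month) | (
        (dob.dt.month == as_of.month) & (dob.dt.day <= as_of.day)
    )
    # Subtract one year where the birthday is still to come
    return as_of.year - dob.dt.year - (~had_birthday).astype(int)

dob_demo = pd.Series(['1994-03-27', '1976-07-03', '1952-04-15'])
ages = age_in_years(dob_demo, pd.Timestamp('2025-07-02'))
print(ages.tolist())
```

Fixing the reference date also means the ages do not change each time the notebook is re-run, unlike `pd.to_datetime("now")`.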
Q1d Step 7 - Intuitive Cluster Labels
Overall summary
The insights gained from the manual validation allow us to label each cluster intuitively, so that Betahelf can deploy these clusters in its working environment more easily and effectively.
Cluster 0: General Wellness & Smoking Advice (Mixed Findings)
Reasoning. Word clouds and common word analysis (i.e. word counts) show prominent terms like "smoking", "cessation", "normal", and "abnormal". The sentiment analysis indicates an overall negative tone, likely due to the presence of both normal and abnormal findings, with a focus on providing general advice related to smoking cessation and regular check-ups. The representative report also reflects advice for lifestyle modifications to reduce future health complications. This cluster appears to group commentaries that are not tied to a specific disease but instead focus on general health, risk factors (especially smoking), and recommendations for monitoring, reflecting a broad patient group.
Business Context. Betahelf can use this cluster to identify patients who will benefit from general health guidance, potentially for preventative programs or follow-up on lifestyle factors.
Cluster 1: Smoking-Related Concerns (Elevated/Abnormal Findings)
Reasoning. Similar to Cluster 0, "smoking" and "cessation" are key terms, but the emphasis shifts more towards "elevated" and "abnormal" findings. The sentiment within the cluster is also predominantly negative and the representative report highlights multiple abnormalities from the lab results. This cluster appears to capture commentaries where smoking is discussed alongside abnormal or elevated lab results, suggesting a more direct link between smoking and the health issues being reported.
Business Context. Betahelf can leverage this cluster to target interventions for smokers with existing abnormal findings, potentially requiring more urgent attention tied to their lab results or specific smoking-cessation programs.
Cluster 2: Liver/Kidney Function (General)
Reasoning. Key terms like "albumin", "globulin", and "liver" dominate this cluster's word cloud and common word list. The sentiment is largely negative, reflecting potential health issues. The representative report explicitly discusses an abnormal albumin/globulin ratio and its implication for liver or kidney diseases. This cluster appears to group report commentaries focusing on general liver function tests and related indicators (like albumin and globulin), likely discussing potential liver or kidney issues without necessarily specifying a severe or acute condition.
Business Context. Betahelf can use this to identify patients with potential early signs of liver or kidney issues based on general markers, allowing for proactive monitoring or further diagnostic steps.
Cluster 3: Liver Function (Specific/Severe - Bilirubin Focus)
Reasoning. Similar to Cluster 2, this cluster also features "liver"-related terms, but "bilirubin" is much more prominent here, as highlighted by the discriminator analysis and word cloud. Terms like "dysfunction", "elevated", and "smoking" also appear, suggesting more severe or specific liver issues. The representative report mentions an elevated bilirubin level and discusses related factors as well as the need for monitoring.
Business Context. Betahelf can use this cluster to prioritise patients with specific and potentially more severe liver concerns, enabling targeted interventions and specialist referrals.
Cluster 4: Electrolyte & Chloride Imbalances
Reasoning. Words like "electrolyte", "imbalance", "abnormal", and, specifically, "chloride" are prominent in this cluster. The sentiment is negative, which is expected for abnormal findings. The representative report highlights a below-normal serum chloride level, indicating possible underlying health issues. This cluster uniquely focuses on electrolyte balance, particularly chloride levels, implying that it captures commentaries related to metabolic or kidney-related issues.
Business Context. Betahelf can target patients in this cluster for evaluation of electrolyte imbalances and related conditions, potentially identifying those at risk of kidney or other metabolic problems.
Cluster 5: Cardiovascular Risk
Reasoning. This cluster is clearly defined by terms like "triglyceride", "cardiovascular", "risk", and "lipid". "Smoking" and "cessation" are also significant, linking lifestyle to risk. The representative report explicitly discusses elevated triglycerides as a risk factor for cardiovascular disease, mentioning age as a factor too. This cluster strongly groups commentaries related to cardiovascular health and risk factors, particularly focusing on lipid levels.
Business Context. Betahelf can effectively use this cluster to identify patients at high risk of cardiovascular diseases based on their lab results and associated commentary, facilitating targeted risk management programs, lifestyle interventions, and preventative care.
Q1e - Summarise clusters
Q1e Step 1
Outline
This step provides a summary, within 500 words, of the cluster analysis we have conducted, including actionable insights for Betahelf and presented in language highly suitable for the Betahelf management team.
Summary
Based on the analysis of the pathology report commentaries, six distinct groups, or clusters, of reports were identified. These clusters were created using a combination of text analysis techniques and validated by examining keywords, sentiment, representative reports, and patient demographics.
The findings indicate that Betahelf's pathologists are indeed providing valuable, context-rich information in their commentaries that goes beyond simply reporting structured lab results. The clusters reveal that the commentaries group around specific medical themes and conditions, highlighting the pathologists' focus on interpreting results within a clinical context. In particular, we found clusters focused on:
Cluster 0: General Wellness & Smoking Advice (Mixed Findings) and Cluster 1: Smoking-Related Concerns (Elevated/Abnormal Findings)
- These clusters contained commentaries offering general guidance, often related to smoking, with variations based on whether findings were mixed or more consistently abnormal.
Cluster 2: Liver/Kidney Function (General) and Cluster 3: Liver Function (Specific/Severe - Bilirubin Focus)
- These clusters contained commentaries discussing liver/kidney health, with Cluster 3 potentially focusing on more specific/severe issues than Cluster 2.
Cluster 4: Electrolyte & Chloride Imbalances
- This cluster contained commentaries specifically addressing electrolyte levels, particularly chloride, and related potential conditions.
Cluster 5: Cardiovascular Risk
- This cluster contained commentaries linking lab results (like triglycerides) to cardiovascular disease risk, often including lifestyle factors like smoking.
The distinct nature of these clusters, particularly those focused on specific medical areas like liver function or cardiovascular risk, demonstrates that the pathologists are adding significant value by providing interpretive context and highlighting potential health implications. This information, which is not readily available in the structured data alone, is crucial for understanding the nuances of a patient's condition.
Actionable Insights
Targeted Patient Programs: The identified clusters can be used to proactively identify patient groups for targeted health programs. For example, patients in the "Cardiovascular Risk" cluster (i.e. Cluster 5) could be enrolled in preventative cardiovascular health programs or lifestyle intervention initiatives focusing on diet and exercise. Similarly, patients in the "Smoking-Related Concerns" cluster (i.e. Cluster 1) could be prioritised for smoking cessation support.
Enhanced Physician Communication: The insights from each cluster can inform communication strategies with physicians, highlighting the key concerns and contextual information provided by pathologists for different patient groups. This can facilitate more focused discussions and treatment planning.
Pathologist Training & Feedback: The analysis of common themes and discriminative keywords within clusters can provide valuable feedback to pathologists, potentially highlighting areas where their commentary is particularly impactful or where further emphasis might be beneficial. Understanding which terms are most indicative of specific clinical concerns can refine reporting practices.
Streamlined Report Review: By categorising reports into these intuitive clusters, Betahelf can potentially streamline the report review process, allowing healthcare professionals to quickly identify reports related to specific areas of concern.
In conclusion, the cluster analysis validates the significant value added by Betahelf's pathologists through their detailed commentaries, providing a rich source of information for proactive healthcare interventions, hence improving patient care and management.
Savepoint
pathology_df_cleaned = pathology_df.assign(cluster = cluster_labels)
pathology_df_cleaned.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/pathology_df_cleaned.pkl')
Q2 - Predict acute diagnoses
Loadpoint
pathology_df_cleaned = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/pathology_df_cleaned.pkl')
Q2a - Clean the dataset
Q2a Step 1 - Exploratory Data Analysis
# Q2a Step 1
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML, clear_output
import warnings
warnings.filterwarnings('ignore')
def show_table(table_name):
    df = assignmentdata[table_name]
    display(Markdown(f'### Preview: {table_name}'))
    display(df.head(10))
    display(Markdown(f'**Shape:** {df.shape}'))
    display(Markdown(f'**Columns:** {list(df.columns)}'))
table_explorer = widgets.interactive(show_table, table_name = widgets.Dropdown(options = list(assignmentdata.keys()), description = 'Table:'))
display(table_explorer)
def enhanced_eda_report(table):
    df = assignmentdata[table]
    display(Markdown(f'### Enhanced EDA Report for {table}'))
    # Create comprehensive summary
    summary_data = []
    for col in df.columns:
        series = df[col]
        null_count = series.isnull().sum()
        null_percentage = (null_count / len(series)) * 100
        unique_count = series.nunique()
        if pd.isna(unique_count):
            unique_count = 0
        summary_data.append({
            'Column': col,
            'Data Type': str(series.dtype),
            'Missing Values': int(null_count),
            'Missing %': f'{null_percentage:.1f}%',
            'Unique Count': int(unique_count),
            'Unique %': f'{(unique_count / len(series) * 100):.1f}%'
        })
    summary_df = pd.DataFrame(summary_data)
    # Display the summary table with proper formatting
    display(Markdown('#### Column Summary'))
    display(summary_df.style.hide(axis = 'index').set_properties(**{
        'text-align': 'left',
        'border': '1px solid #ddd',
        'padding': '8px'
    }).set_table_styles([{
        'selector': 'th',
        'props': [('background-color', '#f2f2f2'), ('font-weight', 'bold'), ('text-align', 'left')]
    }]))
    # Show detailed statistics for numeric columns
    numeric_cols = df.select_dtypes(include = [np.number]).columns
    if len(numeric_cols) > 0:
        display(Markdown('**Numeric Columns Statistics:**'))
        numeric_stats = df[numeric_cols].describe()
        display(numeric_stats.style.format('{:.2f}').set_properties(**{
            'text-align': 'right',
            'border': '1px solid #ddd',
            'padding': '8px'
        }).set_table_styles([{
            'selector': 'th',
            'props': [('background-color', '#f2f2f2'), ('font-weight', 'bold'), ('text-align', 'right')]
        }]))
    # Show datetime columns if any
    datetime_cols = df.select_dtypes(include = ['datetime64']).columns
    if len(datetime_cols) > 0:
        display(Markdown('#### Datetime Columns Range'))
        datetime_data = []
        for col in datetime_cols:
            series = df[col]
            datetime_data.append({
                'Column': col,
                'Min Date': str(series.min()),
                'Max Date': str(series.max())
            })
        datetime_df = pd.DataFrame(datetime_data)
        display(datetime_df.style.hide(axis = 'index').set_properties(**{
            'text-align': 'left',
            'border': '1px solid #ddd',
            'padding': '8px'
        }).set_table_styles([{
            'selector': 'th',
            'props': [('background-color', '#f2f2f2'), ('font-weight', 'bold'), ('text-align', 'left')]
        }]))
# Create interactive EDA report
eda_widget = widgets.interactive(enhanced_eda_report, table = widgets.Dropdown(options = list(assignmentdata.keys()), description = 'Table:'))
display(eda_widget)
interactive(children=(Dropdown(description='Table:', options=('Patient', 'Diagnosis', 'Visit', 'LabResult', 'L…
Q2a Step 2 - Identification of Privacy Issues and Anonymisation of Data
# Q2a Step 2
patient_df_cleaned = patient_df.copy()
# Drop [Title], [GivenName], [Surname], [StreetAddress], [City], [ZipCode],
# [Latitude], [Longitude] to increase anonymity
# The encrypted PatientGuid is sufficient to identify patient-level information
# without [Title], [GivenName], and [Surname].
# When geographies are needed, use state level information (i.e. [StateCode]).
# Retaining [StateCode] also makes it possible to join the [StateDetails] table
# later.
patient_df_cleaned.drop(columns = ["Title", "GivenName", "Surname", "StreetAddress", "City", "ZipCode", "Latitude", "Longitude"], inplace = True)
# Display the first few rows to verify that the relevant columns have been
# dropped
display(patient_df_cleaned.head(3))
# Timestamps will only be used for modelling purposes. All reporting will use
# monthly granularity (as further explained in Q2b below).
| | RowID | ValidFrom | ValidTo | PatientGuid | Gender | DateOfBirth | StateCode | BloodType |
|---|---|---|---|---|---|---|---|---|
| 0 | 6519cb4e-8bb0-463f-80bd-e903b538ef40 | 2020-01-01 | 2099-12-31 23:59:59 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | M | 1994-03-27 | MI | A+ |
| 1 | 19ffd5b9-a36a-4ce5-a210-847ec1222f78 | 2020-01-01 | 2099-12-31 23:59:59 | 0056cdf7-609c-4c4e-8acc-0aaef6f1998e | M | 1976-07-03 | PA | O+ |
| 2 | 96909230-da10-4f52-8476-8b70e82c3f09 | 2020-01-01 | 2099-12-31 23:59:59 | 00c5e26d-e323-47c2-bfcb-a2e9fe95f86d | F | 1952-04-15 | KS | AB+ |
Q2a Step 3 - Cleaning of ICD-9 Codes
# Q2a Step 3
# The current ICD-9 codes are too granular. For example, according to the ICD9
# table, the '005+' ICD-9 codes (a total of 10 codes) all appear to refer to
# the same Group1 (i.e. Infectious And Parasitic Diseases), Group2 (i.e.
# Intestinal Infectious Diseases), and Group3 (i.e. Other food poisoning
# (bacterial)).
# There are also a total of 17553 ICD-9 codes in the ICD9 table. This level of
# granularity is too fine and will only increase the model complexity without
# adding much further significance.
# Check the proportion of ICD-9 codes with only one occurrence
# There is a very high proportion (~30%) of ICD-9 codes with only one occurrence
# in the diagnosis table. The model will not be able to learn much on these
# ICD-9 codes and, hence, they are unlikely to add to the model's predictive
# power.
print(f"Proportion of ICD-9 codes appearing only once: {diagnosis_df['ICD9Code'].value_counts()[diagnosis_df['ICD9Code'].value_counts() == 1].count() / diagnosis_df['ICD9Code'].value_counts().count():.4f}")
# Check the distinct lengths of ICD-9 codes grouped by their first character in
# diagnosis table
print("\nDistinct lengths of ICD-9 codes grouped by their first character in diagnosis table:")
display(diagnosis_df.groupby(diagnosis_df['ICD9Code'].astype(str).str[0]).apply(lambda x: sorted(x['ICD9Code'].astype(str).str.len().unique())))
# Check the distinct lengths of ICD-9 codes grouped by their first character in
# ICD9 table
print("\nDistinct lengths of ICD-9 codes grouped by their first character in ICD9 table:")
display(icd9_df.groupby(icd9_df['ICD9Code'].astype(str).str[0]).apply(lambda x: sorted(x['ICD9Code'].astype(str).str.len().unique())))
# Generally, the main ICD-9 categories are 3 characters long, with
# sub-categories occupying the subsequent digits (e.g. the 002.1 ICD-9 code
# belongs to the 002 group). This includes codes that start with V (V followed
# by 2 digits). The exception is codes that start with E, whose main categories
# are 4 characters long (E followed by 3 digits).
# Truncate each code to its main category: keep the first 3 characters for codes
# that do not start with E and the first 4 characters for codes that start with
# E.
# Function to clean ICD9Code based on length and starting character
def clean_icd9code(icd9code):
    code_len = len(icd9code)
    if (code_len == 4 and not icd9code.startswith('E')) or (code_len == 5 and icd9code.startswith('E')):
        # Cut last character
        return icd9code[:-1]
    elif code_len == 5 and not icd9code.startswith('E'):
        # Cut last two characters
        return icd9code[:-2]
    else:
        return icd9code
# Apply the function to the diagnosis table and ICD9 table
diagnosis_df_cleaned = diagnosis_df.copy()
diagnosis_df_cleaned['ICD9Code_cleaned'] = diagnosis_df_cleaned['ICD9Code'].apply(clean_icd9code)
icd9_df_cleaned = icd9_df.copy()
icd9_df_cleaned['ICD9Code_cleaned'] = icd9_df_cleaned['ICD9Code'].apply(clean_icd9code)
# Display the first few rows to show the new column
display(diagnosis_df_cleaned.head())
display(icd9_df_cleaned.head())
# Verify that the length of ICD-9 codes in the diagnosis table after cleaning
# is 3 for codes that do not start with E and 4 for codes that start with E
print("\nDistinct length of ICD-9 codes in diagnosis table after cleaning:")
display(diagnosis_df_cleaned.groupby(diagnosis_df_cleaned['ICD9Code_cleaned'].astype(str).str[0]).apply(lambda x: sorted(x['ICD9Code_cleaned'].astype(str).str.len().unique())))
# Verify that the length of ICD-9 codes in the ICD9 table after cleaning is 3
# for codes that do not start with E and 4 for codes that start with E
print("\nDistinct length of ICD-9 codes in ICD9 table after cleaning:")
display(icd9_df_cleaned.groupby(icd9_df_cleaned['ICD9Code_cleaned'].astype(str).str[0]).apply(lambda x: sorted(x['ICD9Code_cleaned'].astype(str).str.len().unique())))
# Condense ICD9 table
icd9_df_cleaned_v2 = icd9_df_cleaned[['ICD9Code_cleaned', 'Group1', 'Group2', 'Group3']].drop_duplicates().reset_index(drop = True)
# Display the first few rows to confirm output
print("First few rows of icd9_df_cleaned_v2:")
display(icd9_df_cleaned_v2.head())
# Compare the shape of the new ICD9 table to the old ICD9 table
print("\nShape of icd9_df_cleaned:")
display(icd9_df_cleaned.shape)
print("\nShape of icd9_df_cleaned_v2:")
display(icd9_df_cleaned_v2.shape)
# Verify that proportion of ICD-9 codes appearing only once in diagnosis table has dropped
print(f"Proportion of ICD-9 codes appearing only once: {diagnosis_df_cleaned['ICD9Code_cleaned'].value_counts()[diagnosis_df_cleaned['ICD9Code_cleaned'].value_counts() == 1].count() / diagnosis_df_cleaned['ICD9Code_cleaned'].value_counts().count():.4f}")
Proportion of ICD-9 codes appearing only once: 0.2976
Distinct lengths of ICD-9 codes grouped by their first character in diagnosis table:
| ICD9Code | 0 |
|---|---|
| 0 | [3, 4, 5] |
| 1 | [3, 4, 5] |
| 2 | [3, 4, 5] |
| 3 | [3, 4, 5] |
| 4 | [3, 4, 5] |
| 5 | [3, 4, 5] |
| 6 | [3, 4, 5] |
| 7 | [3, 4, 5] |
| 8 | [3, 4, 5] |
| 9 | [3, 4, 5] |
Distinct lengths of ICD-9 codes grouped by their first character in ICD9 table:
| ICD9Code | 0 |
|---|---|
| 0 | [3, 4, 5] |
| 1 | [3, 4, 5] |
| 2 | [3, 4, 5] |
| 3 | [3, 4, 5] |
| 4 | [3, 4, 5] |
| 5 | [3, 4, 5] |
| 6 | [3, 4, 5] |
| 7 | [3, 4, 5] |
| 8 | [3, 4, 5] |
| 9 | [3, 4, 5] |
| E | [4, 5] |
| V | [3, 4, 5] |
| | DiagnosisGuid | PatientGuid | Timestamp | tz_offset | ICD9Code | DiagnosisDescription | Acute | ICD9Code_cleaned |
|---|---|---|---|---|---|---|---|---|
| 0 | 0001b321-79e1-4d47-8052-9409c8ffa5d1 | dad809ff-b505-46f1-b7f0-37156351c6fd | 2021-03-11 17:23:14 | -06:00 | 4779 | Allergic rhinitis, cause unspecified | False | 477 |
| 1 | 00032606-031b-461f-9aba-b1e3e752da15 | e79b04b1-7d36-4eed-bca8-f956eee7972a | 2022-10-28 21:14:16 | -07:00 | 4619 | Acute sinusitis, unspecified | False | 461 |
| 2 | 000354e4-f147-4017-95b2-bec05486ea9b | 0a6de54f-28d0-4a64-8f0c-bf57a06b77cb | 2023-02-28 20:26:57 | -05:00 | 2167 | Benign neoplasm of skin of lower limb, includi... | False | 216 |
| 3 | 000648fa-f1e9-4faf-a9d3-f03735933c45 | 3ddfb7e6-e55e-4359-9003-f03fdfd1e033 | 2024-12-12 15:26:19 | -05:00 | 2449 | Unspecified hypothyroidism | True | 244 |
| 4 | 00075684-1073-4b2a-ace7-158ac0c8750a | bea2c0d6-9347-4404-9b96-b8e890ac622d | 2023-05-10 15:23:08 | -05:00 | 49390 | Asthma, unspecified type, without mention of s... | True | 493 |
| | ICD9Code | Group1 | Group2 | Group3 | ICD9Code_cleaned |
|---|---|---|---|---|---|
| 0 | 001 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Cholera | 001 |
| 1 | 0010 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Cholera | 001 |
| 2 | 0011 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Cholera | 001 |
| 3 | 0019 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Cholera | 001 |
| 4 | 002 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Typhoid and paratyphoid fevers | 002 |
Distinct length of ICD-9 codes in diagnosis table after cleaning:
| ICD9Code_cleaned | 0 |
|---|---|
| 0 | [3] |
| 1 | [3] |
| 2 | [3] |
| 3 | [3] |
| 4 | [3] |
| 5 | [3] |
| 6 | [3] |
| 7 | [3] |
| 8 | [3] |
| 9 | [3] |
Distinct length of ICD-9 codes in ICD9 table after cleaning:
| ICD9Code_cleaned | 0 |
|---|---|
| 0 | [3] |
| 1 | [3] |
| 2 | [3] |
| 3 | [3] |
| 4 | [3] |
| 5 | [3] |
| 6 | [3] |
| 7 | [3] |
| 8 | [3] |
| 9 | [3] |
| E | [4] |
| V | [3] |
First few rows of icd9_df_cleaned_v2:
| | ICD9Code_cleaned | Group1 | Group2 | Group3 |
|---|---|---|---|---|
| 0 | 001 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Cholera |
| 1 | 002 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Typhoid and paratyphoid fevers |
| 2 | 003 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Other salmonella infections |
| 3 | 004 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Shigellosis |
| 4 | 005 | Infectious And Parasitic Diseases | Intestinal Infectious Diseases | Other food poisoning (bacterial) |
Shape of icd9_df_cleaned:
(17553, 5)
Shape of icd9_df_cleaned_v2:
(1234, 4)
Proportion of ICD-9 codes appearing only once: 0.1518
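For codes of 3 to 5 characters, the truncation rule in this step amounts to keeping the first three characters of a code, or the first four for E codes. A minimal vectorised sketch follows, assuming codes are stored as strings without a decimal point, as in the tables above. The helper name `icd9_category` is hypothetical and the sample codes are illustrative.

```python
import pandas as pd

def icd9_category(codes: pd.Series) -> pd.Series:
    """Truncate ICD-9 codes to category level: 4 characters for E codes, 3 otherwise."""
    codes = codes.astype(str)
    is_e_code = codes.str.startswith('E')
    # Series.where keeps the 3-character prefix for non-E codes and swaps in
    # the 4-character prefix wherever the code starts with E
    return codes.str[:3].where(~is_e_code, codes.str[:4])

sample_codes = pd.Series(['4779', '49390', '001', 'V5481', 'E8490', 'E849'])
categories = icd9_category(sample_codes)
print(categories.tolist())
```

Working on the whole column at once avoids a Python-level `apply` and keeps the rule in one expression, which may be easier to audit than branching on code length.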
Q2a Step 4 - Cleaning of Height Data in 'Visit' Table
# Q2a Step 4
visit_df_cleaned = visit_df.copy()
# From the distribution of heights, height appears to be in inches
print("Range of heights:")
print(f" Height - Min: {visit_df_cleaned['Height'].min():.2f}, Max: {visit_df_cleaned['Height'].max():.2f}")
# Convert height from inches to cm (1 inch = 2.54 cm) for easier analysis
visit_df_cleaned['Height_cm'] = visit_df_cleaned['Height'] * 2.54
# Display the minimum and maximum values for height (in cm)
print("\nRange of heights (converted to metric):")
print(f" Height (cm) - Min: {visit_df_cleaned['Height_cm'].min():.2f}, Max: {visit_df_cleaned['Height_cm'].max():.2f}")
# Manually check heights above 200 cm and clean where necessary (4 rows)
# Observation 32667 appears normal (height slightly above 200 cm)
# The other observations appear to have erroneous heights of over 1000 cm
# The decimal point for these heights appears to be in the wrong place (e.g.
# index 12807 has a height of approximately 1798.32 cm, where 179.832 cm appears
# much more reasonable)
visit_df_cleaned[visit_df_cleaned['Height_cm'] > 200].shape[0]
visit_df_cleaned[visit_df_cleaned['Height_cm'] > 200]
# Solution: Divide (original) height by 10 where (original) height > 80 inches
visit_df_cleaned['Height_cleaned'] = visit_df_cleaned['Height'].apply(lambda x: x / 10 if x > 80 else x)
# Verify that the erroneous heights have been fixed
visit_df_cleaned[visit_df_cleaned['Height_cm'] > 200]
Range of heights:
 Height - Min: 48.00, Max: 708.00
Range of heights (converted to metric):
 Height (cm) - Min: 121.92, Max: 1798.32
| | VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | Temperature | PhysicianSpecialty | Height_cm | Height_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 12807 | 3602e745-69de-4b38-aa93-aab2101e15aa | 26585054-b8cf-4c61-b361-fb414397c3d4 | 2020-02-28 14:01:34 | -05:00 | 708.0 | 175.0 | 24.405 | 126.0 | 78.0 | NaN | 99.3 | Family Practice | 1798.320 | 70.80 |
| 32667 | 4f17ebad-efa5-45e1-8267-a35f960bd1ee | 5f20f7d3-2021-43e5-adc7-3a4165346438 | 2020-03-24 16:04:12 | -05:00 | 79.0 | 260.0 | NaN | 130.0 | 88.0 | NaN | NaN | Family Practice | 200.660 | 79.00 |
| 74930 | 18a4d727-630b-4421-b35f-82526df3fb22 | d796bed5-8755-4a05-a4d8-244d468b55fa | 2021-12-18 18:26:03 | -05:00 | 616.5 | 165.0 | 31.831 | 110.0 | 78.0 | 16.0 | NaN | Unknown | 1565.910 | 61.65 |
| 80207 | 9d1b1132-03a2-4844-9ad6-d21dc82041eb | e5de7efc-106f-4be9-a7a9-df0068e3a466 | 2020-11-28 22:17:06 | -08:00 | 595.8 | 241.0 | 39.684 | 124.0 | 82.0 | NaN | 96.6 | Internal Medicine | 1513.332 | 59.58 |
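The decimal-shift correction can be sketched in vectorised form with `Series.where` and then checked against the converted values. The 80-inch threshold mirrors the lambda in the code above, and the first four sample heights are taken from the output table; the constant names are hypothetical.

```python
import pandas as pd

INCH_TO_CM = 2.54
MAX_PLAUSIBLE_INCHES = 80  # same threshold as the lambda in this step

heights_in = pd.Series([708.0, 79.0, 616.5, 595.8, 62.5])

# Keep values at or below the threshold; divide the rest by 10
heights_clean = heights_in.where(heights_in <= MAX_PLAUSIBLE_INCHES, heights_in / 10)
heights_clean_cm = heights_clean * INCH_TO_CM

print(heights_clean.round(2).tolist())
print(f"Max cleaned height (cm): {heights_clean_cm.max():.2f}")
```

Recomputing the centimetre column from the cleaned heights, rather than filtering on the original `Height_cm`, gives a direct check that no implausible values remain.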
Q2a Step 5 - Cleaning of Systolic and Diastolic BP Data in 'Visit' Table
# Q2a Step 5
# Distribution of BP
print("Range of BP:")
print(f" Systolic BP - Min: {visit_df_cleaned['SystolicBP'].min():.2f}, Max: {visit_df_cleaned['SystolicBP'].max():.2f}")
print(f" Diastolic BP - Min: {visit_df_cleaned['DiastolicBP'].min():.2f}, Max: {visit_df_cleaned['DiastolicBP'].max():.2f}")
# Manually check systolic BP below 75 and clean where necessary (4 rows)
# Observation 63961 appears plausible (systolic BP of 44)
# The other observations appear to have erroneous systolic BP of 14 (likely a
# missing 0 at the end as a systolic BP of 140 appears much more reasonable)
visit_df_cleaned[visit_df_cleaned['SystolicBP'] < 75].shape[0]
visit_df_cleaned[visit_df_cleaned['SystolicBP'] < 75]
# Solution: Multiply systolic BP by 10 where systolic BP = 14
visit_df_cleaned['SystolicBP_cleaned'] = visit_df_cleaned['SystolicBP'].apply(lambda x: x * 10 if x == 14 else x)
# Verify that the erroneous systolic BP have been fixed
visit_df_cleaned[visit_df_cleaned['SystolicBP'] < 75]
# Manually check diastolic BP below 40 and clean where necessary (6 rows)
# Observation 24758 appears plausible (diastolic BP of 30)
# The other observations appear to have erroneous diastolic BPs (likely a
# missing 0 at the end as, for example, a diastolic BP of 90 appears much more
# reasonable)
visit_df_cleaned[visit_df_cleaned['DiastolicBP'] < 40].shape[0]
visit_df_cleaned[visit_df_cleaned['DiastolicBP'] < 40]
# Solution: Multiply diastolic BP by 10 where diastolic BP < 30
visit_df_cleaned['DiastolicBP_cleaned'] = visit_df_cleaned['DiastolicBP'].apply(lambda x: x * 10 if x < 30 else x)
# Verify that the erroneous diastolic BP have been fixed
visit_df_cleaned[visit_df_cleaned['DiastolicBP'] < 40]
Range of BP:
 Systolic BP - Min: 14.00, Max: 301.00
 Diastolic BP - Min: 7.00, Max: 175.00
| | VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | Temperature | PhysicianSpecialty | Height_cm | Height_cleaned | SystolicBP_cleaned | DiastolicBP_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5988 | 74f867b1-2630-45c0-b9e2-7b8d00515cdc | 12bc880a-7829-4f61-83b3-7216c0fe09dd | 2022-02-23 20:59:39 | -08:00 | 62.5 | 218.1 | 38.890 | 132.0 | 9.0 | 17.0 | NaN | Family Practice | 158.750 | 62.5 | 132.0 | 90.0 |
| 5989 | 1df3140f-d425-41f2-b226-5a13757bd713 | 12bc880a-7829-4f61-83b3-7216c0fe09dd | 2022-04-27 16:18:23 | -08:00 | 62.6 | NaN | NaN | 135.0 | 7.0 | 18.0 | 97.0 | Family Practice | 159.004 | 62.6 | 135.0 | 70.0 |
| 24758 | d3f04c46-54ba-4dd4-963b-70d044ab9e14 | 46164d87-7331-477a-b383-41afc957709a | 2020-02-18 14:08:31 | -05:00 | NaN | 102.0 | NaN | 130.0 | 30.0 | 16.0 | NaN | Internal Medicine | NaN | NaN | 130.0 | 30.0 |
| 30609 | 6b845aaa-99ae-4209-80b8-377510445db1 | 58cb121c-d78a-493b-be5d-5523b13098ef | 2020-02-14 00:56:01 | -08:00 | 62.0 | 134.0 | 21.635 | 128.0 | 18.0 | 20.0 | NaN | Unknown | 157.480 | 62.0 | 128.0 | 180.0 |
| 40795 | e7063596-4f33-43c1-9a3c-e9ed53baec55 | 77fcab5a-cb44-4279-a5b8-b0c8a778e1a0 | 2020-01-31 20:43:47 | -05:00 | 58.5 | 131.0 | 24.214 | 155.0 | 14.0 | 20.0 | NaN | Internal Medicine | 148.590 | 58.5 | 155.0 | 140.0 |
| 40797 | c79a4913-8eab-4c63-81be-72c8ffcf1f72 | 77fcab5a-cb44-4279-a5b8-b0c8a778e1a0 | 2020-02-17 18:13:10 | -05:00 | NaN | 122.0 | NaN | 147.0 | 18.0 | 20.0 | NaN | Internal Medicine | NaN | NaN | 147.0 | 180.0 |
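A tenfold rescaling can produce values that are themselves implausible, for instance a diastolic reading above the systolic reading from the same visit. A minimal sketch of a follow-up check is given below, using the cleaned values from the output table above. The rule that diastolic pressure should be lower than systolic pressure is an assumption introduced here and is not part of the original cleaning.

```python
import pandas as pd

# Cleaned readings copied from the output table (index = row label)
bp_demo = pd.DataFrame({
    'SystolicBP_cleaned': [132.0, 135.0, 130.0, 128.0, 155.0, 147.0],
    'DiastolicBP_cleaned': [90.0, 70.0, 30.0, 180.0, 140.0, 180.0],
}, index = [5988, 5989, 24758, 30609, 40795, 40797])

# Flag visits where the cleaned diastolic reading is not below the systolic one
suspect = bp_demo[bp_demo['DiastolicBP_cleaned'] >= bp_demo['SystolicBP_cleaned']]
print(suspect)
```

Rows flagged this way may deserve a second look, since a different correction (or treating the value as missing) could be more appropriate than multiplying by 10.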
Q2a Step 6 - Cleaning of Physician Specialty Data in 'Visit' Table
# Q2a Step 6
# There are 3 patients with visits that have missing physician specialty
# They are also missing almost all other information in the 'visit' table
display(visit_df_cleaned[visit_df_cleaned['PhysicianSpecialty'].isna()])
# These 3 patients have other visits
display(visit_df_cleaned[visit_df_cleaned['PatientGuid'].isin(visit_df_cleaned[visit_df_cleaned['PhysicianSpecialty'].isna()]['PatientGuid'].unique())]['PatientGuid'].value_counts())
# Imputing these 3 data points is difficult as, without further information, it
# is not clear what the specialty of these visits is.
# Dropping 3 data points will not be material in such a large dataset with over
# 89000 visits and 26000 diagnoses, and these patients' history will still be in
# the dataset since they have other visits.
# Furthermore, a missing physician specialty will cause issues later when
# joining with the 'Specialty' table.
# Solution: Drop these particular visits from the 'Visit' table.
# Drop rows in visit_df_cleaned where physician specialty is missing
visit_df_cleaned = visit_df_cleaned[~visit_df_cleaned['PhysicianSpecialty'].isna()].copy()
# Verify that there are no more rows with missing physician specialty in
# visit_df_cleaned
print(f"Number of rows remaining in visit_df_cleaned with blank specialty: {len(visit_df_cleaned[visit_df_cleaned['PhysicianSpecialty'].isna()])}")
| | VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | Temperature | PhysicianSpecialty | Height_cm | Height_cleaned | SystolicBP_cleaned | DiastolicBP_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 8143 | b4a761e6-6f52-446f-bc63-f283e8e0ca63 | 19506490-0d26-474c-ba3f-be2547e13405 | 2022-01-08 14:56:55 | -05:00 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 80186 | f7ce62a6-4afa-4f16-8e52-88f9a4dba0bf | e5ce4e28-63e4-4102-9d4e-8e6071cb7b16 | 2021-04-20 13:12:24 | -05:00 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 82459 | ebc986ba-d325-4714-b03b-d2b220c74659 | ec3e1ad3-84ad-41c5-9bc3-12f10e26af74 | 2022-08-15 19:58:01 | -08:00 | 61.5 | 181.0 | NaN | NaN | 79.0 | 16.0 | NaN | NaN | 156.21 | 61.5 | NaN | 79.0 |
| PatientGuid | count |
|---|---|
| ec3e1ad3-84ad-41c5-9bc3-12f10e26af74 | 43 |
| 19506490-0d26-474c-ba3f-be2547e13405 | 17 |
| e5ce4e28-63e4-4102-9d4e-8e6071cb7b16 | 12 |
Number of rows remaining in visit_df_cleaned with blank specialty: 0
Q2a Step 7 - Cleaning of Abnormal Values in 'LabObservation' Table
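As a compact illustration of the reference-range parsing used in this step, the sketch below returns both bounds from a single helper. It is a simplified stand-in rather than the notebook's implementation: `parse_reference_range` is a hypothetical name, and the accepted formats ('3.5-5.0', '>=3', '>3', '<3') are the ones named in the code comments of this step.

```python
import math
import re

NUM = r'[-+]?\d+(?:\.\d+)?'
RANGE_RE = re.compile(rf'^\s*({NUM})\s*-\s*({NUM})\s*$')   # e.g. '3.5-5.0'
LOWER_RE = re.compile(rf'^\s*>=?\s*({NUM})\s*$')           # e.g. '>=3' or '>3'
UPPER_RE = re.compile(rf'^\s*<\s*({NUM})\s*$')             # e.g. '<3'

def parse_reference_range(text):
    """Return (lower, upper); a bound that cannot be parsed is NaN."""
    if text is None or (isinstance(text, float) and math.isnan(text)):
        return (math.nan, math.nan)
    s = str(text)
    m = RANGE_RE.match(s)
    if m:
        return (float(m.group(1)), float(m.group(2)))
    m = LOWER_RE.match(s)
    if m:
        return (float(m.group(1)), math.nan)
    m = UPPER_RE.match(s)
    if m:
        return (math.nan, float(m.group(1)))
    # Free-text ranges such as 'negative' yield no numeric bounds
    return (math.nan, math.nan)

for example in ['3.5-5.0', '>= 3', '<3', 'negative']:
    print(example, parse_reference_range(example))
```

Returning NaN for an unparseable bound means that a later comparison against that bound evaluates to False, so such rows are not flagged as abnormal by the range check alone.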
# Q2a Step 7
random_seed = 41399
labobservation_df_cleaned = labobservation_df.copy()
# Function to extract lower bound of reference range where possible
def extract_lower_bound(reference_range):
    if pd.isna(reference_range):
        return np.nan
    rr_str = str(reference_range).strip()
    # Pattern for range like '3.5-5.0'
    range_match = re.match(r'^\s*([\-+]?\d+(\.\d+)?)\s*-\s*([\-+]?\d+(\.\d+)?)\s*$', rr_str)
    if range_match:
        try:
            return float(range_match.group(1))
        except ValueError:
            return np.nan
    # Pattern for '>=3'
    gte_match = re.match(r'^\s*>=\s*([\-+]?\d+(\.\d+)?)\s*$', rr_str)
    if gte_match:
        try:
            return float(gte_match.group(1))
        except ValueError:
            return np.nan
    # Pattern for '>3'
    gt_match = re.match(r'^\s*>\s*([\-+]?\d+(\.\d+)?)\s*$', rr_str)
    if gt_match:
        try:
            return float(gt_match.group(1))
        except ValueError:
            return np.nan
    # Return NaN if no pattern matches
    return np.nan
# Function to extract upper bound of reference range where possible
def extract_upper_bound(reference_range):
    if pd.isna(reference_range):
        return np.nan
    rr_str = str(reference_range).strip()
    # Pattern for range like '3.5-5.0'
    range_match = re.match(r'^\s*([\-+]?\d+(\.\d+)?)\s*-\s*([\-+]?\d+(\.\d+)?)\s*$', rr_str)
    if range_match:
        try:
            return float(range_match.group(3))
        except ValueError:
            return np.nan
    # Pattern for '<3' or '< 3'
    lt_match = re.match(r'^\s*<\s*([\-+]?\d+(\.\d+)?)\s*$', rr_str)
    if lt_match:
        try:
            return float(lt_match.group(1))
        except ValueError:
            return np.nan
    # Return NaN if no pattern matches
    return np.nan
# Apply the functions on labobservation_df_cleaned
labobservation_df_cleaned['ReferenceRange_LB'] = labobservation_df_cleaned['ReferenceRange'].apply(extract_lower_bound)
labobservation_df_cleaned['ReferenceRange_UB'] = labobservation_df_cleaned['ReferenceRange'].apply(extract_upper_bound)
def clean_abnormal_flags(df: pd.DataFrame) -> pd.Series:
    # Initialise a result series with False, assuming no abnormality by default
    is_abnormal = pd.Series(False, index = df.index)
    # 1. If an explicit boolean is present in [IsAbnormalValue], use it
    # directly.
    is_abnormal = is_abnormal | df['IsAbnormalValue']
    # 2. If [AbnormalFlags] is not empty, treat as abnormal.
    is_abnormal = is_abnormal | df['AbnormalFlags'].notna()
    # 3. Check if observation value is outside the reference range.
    abnormal_by_range = (
        (df['ObservationValue'].notna() & df['ObservationValue'].lt(df['ReferenceRange_LB']) & df['ReferenceRange_LB'].notna())
        | (df['ObservationValue'].notna() & df['ObservationValue'].gt(df['ReferenceRange_UB']) & df['ReferenceRange_UB'].notna())
    )
    # Combine with previous checks
    is_abnormal = is_abnormal | abnormal_by_range
    # Return the final result
    return is_abnormal
# Apply the function to create a new column
labobservation_df_cleaned["IsAbnormalValue_cleaned"] = clean_abnormal_flags(labobservation_df_cleaned)
# Verify that the above cleaning works as expected by sampling 5 random rows
display(labobservation_df_cleaned.sample(5, random_state = random_seed))
# Get value counts for both columns and concatenate into a single DataFrame
abnormal_value_counts_original = labobservation_df_cleaned['IsAbnormalValue'].value_counts(dropna = False)
abnormal_value_counts_cleaned = labobservation_df_cleaned['IsAbnormalValue_cleaned'].value_counts(dropna = False)
abnormal_counts_comparison = pd.concat([abnormal_value_counts_original, abnormal_value_counts_cleaned], axis = 1)
# Rename the columns for clarity
abnormal_counts_comparison.columns = ['IsAbnormalValue', 'IsAbnormalValue_cleaned']
# Calculate the proportion of True for each column
proportion_true_original = abnormal_value_counts_original.get(True, 0) / abnormal_value_counts_original.sum() * 100
proportion_true_cleaned = abnormal_value_counts_cleaned.get(True, 0) / abnormal_value_counts_cleaned.sum() * 100
proportion_true = pd.Series({
    'IsAbnormalValue': proportion_true_original,
    'IsAbnormalValue_cleaned': proportion_true_cleaned
}, name = '% of True')
# Display the distributions
print("Comparison of distribution of IsAbnormalValue and IsAbnormalValue_cleaned:")
display(pd.concat([abnormal_counts_comparison, pd.DataFrame([proportion_true])]))
# Condense the 'LabObservation' table by LabResultGuid for more meaningful
# interpretation by the model
labobservation_df_cleaned_v2 = labobservation_df_cleaned.groupby('LabResultGuid')['IsAbnormalValue_cleaned'].max().rename('AnyAbnormalValue_cleaned').reset_index()
# Verify that a lab result shows normal if all of the tests are normal
display(labobservation_df_cleaned[labobservation_df_cleaned['LabResultGuid'] == '00635d82-4552-4b6b-a68c-ec55fb26c683'])
display(labobservation_df_cleaned_v2[labobservation_df_cleaned_v2['LabResultGuid'] == '00635d82-4552-4b6b-a68c-ec55fb26c683'])
# Verify that a lab result shows abnormal if any one of the tests is abnormal
display(labobservation_df_cleaned[labobservation_df_cleaned['LabResultGuid'] == '002fd0ba-e24a-456c-83a3-fa997042a895'])
display(labobservation_df_cleaned_v2[labobservation_df_cleaned_v2['LabResultGuid'] == '002fd0ba-e24a-456c-83a3-fa997042a895'])
# Verify that each lab result appears only once in labobservation_df_cleaned_v2
print(f"Number of LabResultGuids appearing more than once in labobservation_df_cleaned_v2: {len(labobservation_df_cleaned_v2['LabResultGuid'].value_counts()[labobservation_df_cleaned_v2['LabResultGuid'].value_counts() > 1])}")
|  | LabObservationGuid | LabResultGuid | HL7Text | ObservationValue | Units | ReferenceRange | AbnormalFlags | IsAbnormalValue | ReferenceRange_LB | ReferenceRange_UB | IsAbnormalValue_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 4632 | f3330061-04e1-4689-8a7e-283cdf84b303 | 26618943-511a-4d04-ab1e-00d924cd3579 | Globulin | 2.4 | g/dL | 1.1-2.5 | NaN | False | 1.1 | 2.5 | False |
| 98262 | 063486b3-f163-41ef-8a17-ec7c531bb5f5 | f2540672-5721-4208-86c2-d2e20d80473f | Protein Total | 6.8 | g/dL | 0.0-1.2 | NaN | False | 0.0 | 1.2 | True |
| 48353 | ec467dd2-4d23-43fe-8038-ee07e4deb1af | a4d96ae7-8e61-4b61-bfbf-c71d687c2e24 | Absolute Lymph | 102.0 | uIU/ml | NaN | NaN | False | NaN | NaN | False |
| 128997 | 8e959451-c08b-4b78-92ab-e218108ed622 | 90db37fb-2dd3-49d5-9d95-d48ef077fd3f | Triglyceride | 7.2 | x10E3/uL | NaN | NaN | False | NaN | NaN | False |
| 87818 | 24a10967-1629-44c4-8ac3-ff986e15c57f | 9e769f46-0a9a-4fa3-83db-2066e5f92552 | Protein Total | 6.3 | g/dL | 0.0-1.2 | NaN | False | 0.0 | 1.2 | True |
Comparison of distribution of IsAbnormalValue and IsAbnormalValue_cleaned:
|  | IsAbnormalValue | IsAbnormalValue_cleaned |
|---|---|---|
| False | 126014.000000 | 74055.000000 |
| True | 8255.000000 | 60214.000000 |
| % of True | 6.148106 | 44.845795 |
|  | LabObservationGuid | LabResultGuid | HL7Text | ObservationValue | Units | ReferenceRange | AbnormalFlags | IsAbnormalValue | ReferenceRange_LB | ReferenceRange_UB | IsAbnormalValue_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 84803 | ef111fe2-1057-4a60-9009-5603657fb430 | 00635d82-4552-4b6b-a68c-ec55fb26c683 | Progesterone | 3.4 | mg/dL | 0.9-5.2 | NaN | False | 0.9 | 5.2 | False |
| 84804 | 28094793-dc11-4851-a05f-4594337ecc35 | 00635d82-4552-4b6b-a68c-ec55fb26c683 | Bilirubin | NaN | NaN | NaN | NaN | False | NaN | NaN | False |
|  | LabResultGuid | AnyAbnormalValue_cleaned |
|---|---|---|
| 15 | 00635d82-4552-4b6b-a68c-ec55fb26c683 | False |
|  | LabObservationGuid | LabResultGuid | HL7Text | ObservationValue | Units | ReferenceRange | AbnormalFlags | IsAbnormalValue | ReferenceRange_LB | ReferenceRange_UB | IsAbnormalValue_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 124966 | 9530a070-69b0-45b9-8098-7ffb6706ed82 | 002fd0ba-e24a-456c-83a3-fa997042a895 | Albumin / Globulin Ratio | NaN | NaN | NaN | NaN | False | NaN | NaN | False |
| 124967 | 7e0393fe-b6fa-46d6-950b-9efec8127bcb | 002fd0ba-e24a-456c-83a3-fa997042a895 | Bacteria | NaN | NaN | NaN | NaN | False | NaN | NaN | False |
| 124968 | e49bda39-7642-432c-a288-5bb5a1d81916 | 002fd0ba-e24a-456c-83a3-fa997042a895 | Bilirubin | 0.4 | mg/dL | 97-108 | NaN | False | 97.0 | 108.0 | True |
|  | LabResultGuid | AnyAbnormalValue_cleaned |
|---|---|---|
| 5 | 002fd0ba-e24a-456c-83a3-fa997042a895 | True |
Number of LabResultGuids appearing more than once in labobservation_df_cleaned_v2: 0
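The two bound-extraction functions in Step 7 repeat most of their logic. As a minimal sketch, assuming reference ranges only take the forms 'a-b', '>a', '>=a' and '<b', both bounds could be parsed in one pass. The helper name `parse_reference_range`, the toy data, and the extra handling of '<=b' are assumptions for illustration and are not part of the original notebook.

```python
import re

import numpy as np
import pandas as pd

# Hypothetical helper: parse both bounds of a reference range in one pass.
_NUM = r'[-+]?\d+(?:\.\d+)?'
_RANGE = re.compile(rf'^\s*({_NUM})\s*-\s*({_NUM})\s*$')
_LOWER = re.compile(rf'^\s*>=?\s*({_NUM})\s*$')   # '>3' or '>=3'
_UPPER = re.compile(rf'^\s*<=?\s*({_NUM})\s*$')   # '<3' or '<=3' (assumed extension)

def parse_reference_range(reference_range):
    """Return (lower, upper); NaN where a bound cannot be determined."""
    if pd.isna(reference_range):
        return (np.nan, np.nan)
    text = str(reference_range)
    match = _RANGE.match(text)
    if match:
        return (float(match.group(1)), float(match.group(2)))
    match = _LOWER.match(text)
    if match:
        return (float(match.group(1)), np.nan)
    match = _UPPER.match(text)
    if match:
        return (np.nan, float(match.group(1)))
    return (np.nan, np.nan)

# Toy data for illustration only
toy = pd.DataFrame({
    'ObservationValue': [2.4, 6.8, 0.4, 5.0],
    'ReferenceRange': ['1.1-2.5', '0.0-1.2', '>=3', None],
})
bounds = pd.DataFrame(
    toy['ReferenceRange'].apply(parse_reference_range).tolist(),
    index = toy.index, columns = ['LB', 'UB'],
)
toy = toy.join(bounds)
# Comparisons against NaN evaluate to False, so a missing bound never flags a row
toy['OutOfRange'] = toy['ObservationValue'].lt(toy['LB']) | toy['ObservationValue'].gt(toy['UB'])
print(toy)
```

Because comparisons against NaN are False, rows without a usable reference range stay unflagged, which mirrors the behaviour of the range check in `clean_abnormal_flags`.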
Q2a Step 8 - Cleaning of Description Data in 'Smoking' Table¶
# Q2a Step 8
smoking_df_cleaned = smoking_df.copy()
# Display the count of each description
display(smoking_df_cleaned.groupby(['Description']).size().reset_index(name = 'Count'))
# Fix typographical error in 'cigaretttes'
smoking_df_cleaned['Description_cleaned'] = smoking_df_cleaned['Description'].str.replace('cigaretttes', 'cigarettes', regex = False)
# Verify that the typographical error in 'cigaretttes' has been fixed
display(smoking_df_cleaned.groupby(['Description_cleaned']).size().reset_index(name = 'Count'))
|  | Description | Count |
|---|---|---|
| 0 | 0 cigarettes per day (non-smoker or less than ... | 204 |
| 1 | 0 cigarettes per day (previous smoker) | 54 |
| 2 | 0 cigaretttes per day (non-smoker or less than... | 678 |
| 3 | 0 cigaretttes per day (previous smoker) | 836 |
| 4 | 1-2 packs per day | 100 |
| 5 | 2 or more packs per day | 7 |
| 6 | Few (1-3) cigarettes per day | 65 |
| 7 | Up to 1 pack per day | 147 |
|  | Description_cleaned | Count |
|---|---|---|
| 0 | 0 cigarettes per day (non-smoker or less than ... | 882 |
| 1 | 0 cigarettes per day (previous smoker) | 890 |
| 2 | 1-2 packs per day | 100 |
| 3 | 2 or more packs per day | 7 |
| 4 | Few (1-3) cigarettes per day | 65 |
| 5 | Up to 1 pack per day | 147 |
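Step 8 fixes one known misspelling with a literal replacement. A slightly more general sketch, assuming the only variation is extra 't' characters in "cigarettes", uses a regular expression so that similar misspellings collapse to the same label. The toy series below is made up for illustration.

```python
import pandas as pd

# Toy data for illustration only
toy = pd.Series([
    '0 cigarettes per day (previous smoker)',
    '0 cigaretttes per day (previous smoker)',
    '0 cigaretttes per day (previous smoker)',
    'Up to 1 pack per day',
])
# Collapse any run of two or more 't's before 'es' to the correct spelling
cleaned = toy.str.replace(r'cigaret{2,}es', 'cigarettes', regex = True)
counts = cleaned.value_counts()
print(counts)
```

After the replacement, the misspelt and correctly spelt variants are counted under a single description.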
Q2a Step 9 - Cleaning of 'Specialty' Table¶
# Q2a Step 9
specialty_df_cleaned = specialty_df.copy()
# Drop 'empty' row
display(specialty_df_cleaned[specialty_df_cleaned['PhysicianSpecialty'].isna()])
specialty_df_cleaned = specialty_df_cleaned.dropna(subset = ['PhysicianSpecialty'])
# Verify that the empty row has been dropped
display(specialty_df_cleaned[specialty_df_cleaned['PhysicianSpecialty'].isna()])
|  | PhysicianSpecialty | Specialty | SpecialtyGroup |
|---|---|---|---|
| 64 | NaN | NaN | NaN |
| PhysicianSpecialty | Specialty | SpecialtyGroup |
|---|---|---|
Q2a Step 10 - Cleaning of Specialty Group data in 'Specialty' Table¶
# Q2a Step 10
# Manually impute missing SpecialtyGroup based on existing groupings
display(specialty_df_cleaned['SpecialtyGroup'].unique())
display(specialty_df_cleaned[specialty_df_cleaned['SpecialtyGroup'].isna()])
specialty_df_cleaned['SpecialtyGroup_cleaned'] = specialty_df_cleaned['SpecialtyGroup']
# Set SpecialtyGroup_cleaned = 'Therapeutic' where PhysicianSpecialty = 'Massage
# Therapy' or 'Speech Therapy'
# Set SpecialtyGroup_cleaned = 'Unknown' where PhysicianSpecialty = 'Unknown'
# Else set SpecialtyGroup_cleaned = 'Medicine'
specialty_df_cleaned.loc[specialty_df_cleaned['PhysicianSpecialty'].isin(['Massage Therapy', 'Speech Therapy']), 'SpecialtyGroup_cleaned'] = 'Therapeutic'
specialty_df_cleaned.loc[specialty_df_cleaned['PhysicianSpecialty'] == 'Unknown', 'SpecialtyGroup_cleaned'] = 'Unknown'
specialty_df_cleaned.loc[specialty_df_cleaned['SpecialtyGroup_cleaned'].isna(), 'SpecialtyGroup_cleaned'] = 'Medicine'
# Check again for any further missing SpecialtyGroup_cleaned
display(specialty_df_cleaned[specialty_df_cleaned['SpecialtyGroup_cleaned'].isna()])
array(['Medicine', 'Surgery', nan, 'Diagnostic and Therapeutic',
'Diagnostic', 'Therapeutic'], dtype=object)
|  | PhysicianSpecialty | Specialty | SpecialtyGroup |
|---|---|---|---|
| 5 | Alternative Medical Practitioner | NaN | NaN |
| 8 | Chiropractic | NaN | NaN |
| 10 | Clinical Pharmacology | Medical Research | NaN |
| 26 | Massage Therapy | NaN | NaN |
| 28 | Naturopathy - Acupuncture | NaN | NaN |
| 35 | Nutrition | Dietetics | NaN |
| 38 | Optometry | NaN | NaN |
| 49 | Preventive Medicine | NaN | NaN |
| 52 | Psychology | NaN | NaN |
| 56 | Speech Therapy | NaN | NaN |
| 57 | Sports Medicine | NaN | NaN |
| 63 | Unknown | NaN | NaN |
| PhysicianSpecialty | Specialty | SpecialtyGroup | SpecialtyGroup_cleaned |
|---|---|---|---|
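The rule-based imputation in Step 10 can also be written as a lookup followed by a fallback. This is a minimal sketch on toy data, assuming the same three rules (therapy specialties, 'Unknown', otherwise 'Medicine'); the `overrides` dictionary is a hypothetical name. Unlike the `.loc` assignments above, it only fills rows whose SpecialtyGroup is missing, which gives the same result when the overridden specialties have no existing group.

```python
import numpy as np
import pandas as pd

# Toy data for illustration only
toy = pd.DataFrame({
    'PhysicianSpecialty': ['General Surgery', 'Massage Therapy', 'Unknown', 'Optometry'],
    'SpecialtyGroup': ['Surgery', np.nan, np.nan, np.nan],
})
# Hypothetical lookup of specialties that should not default to 'Medicine'
overrides = {
    'Massage Therapy': 'Therapeutic',
    'Speech Therapy': 'Therapeutic',
    'Unknown': 'Unknown',
}
# Value to use when SpecialtyGroup is missing: the override if present, else 'Medicine'
fallback = toy['PhysicianSpecialty'].map(overrides).fillna('Medicine')
# Existing groups are kept; only missing ones take the fallback
toy['SpecialtyGroup_cleaned'] = toy['SpecialtyGroup'].fillna(fallback)
print(toy)
```

Keeping the rules in one dictionary makes it easier to review or extend the mapping without adding further assignment lines.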
Q2a Step 11 - Cleaning of Quantity Data in 'Prescription' Table¶
# Q2a Step 11
prescription_df_cleaned = prescription_df.copy()
# Show minimum and maximum quantity values
print(f"Minimum Quantity: {prescription_df_cleaned['Quantity'].min()}")
print(f"Maximum Quantity: {prescription_df_cleaned['Quantity'].max()}")
# Show top 10 unique quantity values in descending order
print("\nTop 10 quantities:")
display(np.sort(prescription_df_cleaned['Quantity'].unique())[::-1][:10])
# Count rows with quantity of 0: 661
print(f"\nNumber of rows with quantity of 0: {len(prescription_df_cleaned[prescription_df_cleaned['Quantity'] == 0])}")
# One row with quantity of 90023
display(prescription_df_cleaned[prescription_df_cleaned['Quantity'] == 90023])
# Does not make sense to have a prescription with quantity of 0 or 90023
# Solution: Drop these rows
prescription_df_cleaned = prescription_df_cleaned[(prescription_df_cleaned['Quantity'] > 0) & (prescription_df_cleaned['Quantity'] < 90023)]
# Verify that there are no more rows where quantity is 0 or 90023
display(prescription_df_cleaned[(prescription_df_cleaned['Quantity'] == 0) | (prescription_df_cleaned['Quantity'] == 90023)])
Minimum Quantity: 0
Maximum Quantity: 90023
Top 10 quantities:
array([90023, 6014, 3784, 3085, 3030, 2814, 2702, 2258, 1714,
1200])
Number of rows with quantity of 0: 661
|  | PrescriptionGuid | PatientGuid | Timestamp | tz_offset | Quantity | NumberOfRefills | RefillAsNeeded | GenericAllowed | NdcCode | MedicationName | MedicationStrength | Schedule |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 78931 | 78cc5bd0-687f-4db7-b936-af27753268a6 | ba3c36e6-93b3-4925-8f67-195caedffc0f | 2021-10-20 18:45:05 | -06:00 | 90023 | 5 | False | True | 0378-5209 | AmLODIPine Besylate (amLODIPine) oral tablet | 5 mg | NaN |
| PrescriptionGuid | PatientGuid | Timestamp | tz_offset | Quantity | NumberOfRefills | RefillAsNeeded | GenericAllowed | NdcCode | MedicationName | MedicationStrength | Schedule |
|---|---|---|---|---|---|---|---|---|---|---|---|
Q2a Step 12 - Cleaning of Refills Data in 'Prescription' Table¶
# Q2a Step 12
# Show minimum and maximum refill values
print(f"Minimum Refills: {prescription_df_cleaned['NumberOfRefills'].min()}")
print(f"Maximum Refills: {prescription_df_cleaned['NumberOfRefills'].max()}")
# Show top 5 unique refill values in descending order
print("\nTop 5 refill values:")
display(np.sort(prescription_df_cleaned['NumberOfRefills'].unique())[::-1][:5])
# Count rows with refills as 110 or 120: 3
print(f"\nNumber of rows with refills greater than 45: {len(prescription_df[(prescription_df['NumberOfRefills'] == 110) | (prescription_df['NumberOfRefills'] == 120)])}")
# Does not make sense to have a prescription with such high number of refills
# Solution: Drop these rows
prescription_df_cleaned = prescription_df_cleaned[prescription_df_cleaned['NumberOfRefills'] <= 45]
# Verify that there are no more rows where refills is 110 or 120
display(prescription_df_cleaned[(prescription_df_cleaned['NumberOfRefills'] == 110) | (prescription_df_cleaned['NumberOfRefills'] == 120)])
Minimum Refills: 0
Maximum Refills: 120
Top 5 refill values:
array([120, 110, 45, 30, 14])
Number of rows with refills greater than 45: 3
| PrescriptionGuid | PatientGuid | Timestamp | tz_offset | Quantity | NumberOfRefills | RefillAsNeeded | GenericAllowed | NdcCode | MedicationName | MedicationStrength | Schedule |
|---|---|---|---|---|---|---|---|---|---|---|---|
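Steps 11 and 12 both drop rows whose values fall outside a plausible range. As a sketch on toy data, the two filters can be expressed with `Series.between`, using the same cut-offs as above (quantity strictly between 0 and 90023, refills at most 45). The lower bound of 0 for refills is an added assumption; the `inclusive='neither'` option requires pandas 1.3 or later.

```python
import pandas as pd

# Toy data for illustration only
toy = pd.DataFrame({
    'Quantity': [0, 30, 90, 90023, 60],
    'NumberOfRefills': [0, 5, 120, 2, 45],
})
# Quantity must be strictly between 0 and 90023; refills must be between 0 and 45
valid = (
    toy['Quantity'].between(0, 90023, inclusive = 'neither')
    & toy['NumberOfRefills'].between(0, 45)
)
cleaned = toy[valid]
print(f"Dropped {len(toy) - len(cleaned)} of {len(toy)} rows")
```

Building a single boolean mask first makes it straightforward to report how many rows each rule removes before the filter is applied.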
Q2a Step 13 - Cleaning of 'Prescription' Table¶
# Q2a Step 13
# Check for PatientGuids not present in 'Patient' table (84905 rows or ~78%)
prescription_df_cleaned[~prescription_df_cleaned['PatientGuid'].isin(patient_df_cleaned['PatientGuid'])].shape[0]
prescription_df_cleaned[~prescription_df_cleaned['PatientGuid'].isin(patient_df_cleaned['PatientGuid'])].shape[0] / len(prescription_df_cleaned)
# Drop rows where PatientGuid is not present in 'Patient' table
prescription_df_cleaned_v2 = prescription_df_cleaned[prescription_df_cleaned['PatientGuid'].isin(patient_df_cleaned['PatientGuid'])]
# Verify that no more orphaned rows exist
prescription_df_cleaned_v2[~prescription_df_cleaned_v2['PatientGuid'].isin(patient_df_cleaned['PatientGuid'])].shape[0]
0
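The orphan check in Step 13 uses `isin`. An equivalent sketch, shown here on toy tables, is a left merge with `indicator=True`, which labels each prescription row as matched ('both') or unmatched ('left_only') so the two groups can be counted and separated in one step. The table contents are made up for illustration.

```python
import pandas as pd

# Toy tables for illustration only
patients = pd.DataFrame({'PatientGuid': ['a', 'b']})
prescriptions = pd.DataFrame({
    'PrescriptionGuid': ['p1', 'p2', 'p3', 'p4'],
    'PatientGuid': ['a', 'c', 'b', 'c'],
})
# Left merge keeps every prescription row; '_merge' records whether a patient was found
flagged = prescriptions.merge(
    patients.drop_duplicates(), on = 'PatientGuid', how = 'left', indicator = True
)
orphaned = flagged[flagged['_merge'] == 'left_only']
kept = flagged[flagged['_merge'] == 'both'].drop(columns = '_merge')
print(f"{len(orphaned)} of {len(flagged)} rows ({len(orphaned) / len(flagged):.0%}) have no matching patient")
```

Dropping duplicates from the parent keys before merging avoids accidentally multiplying prescription rows.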
Q2a Step 14 - Checking of Table Join Keys¶
# Q2a Step 14
# 'Patient' table is the parent table, so no need to check that. Instead, all
# other tables will be checked against this table.
# Check if all PatientGuids in diagnosis_df_cleaned, visit_df_cleaned,
# labresult_df, smoking_df_cleaned, and prescription_df_cleaned_v2 exist in
# patient_df_cleaned
dataframes_to_check = {
    'diagnosis_df_cleaned': diagnosis_df_cleaned,
    'visit_df_cleaned': visit_df_cleaned,
    'labresult_df': labresult_df,
    'smoking_df_cleaned': smoking_df_cleaned,
    'prescription_df_cleaned_v2': prescription_df_cleaned_v2
}
for df_name, df_to_check in dataframes_to_check.items():
    if df_to_check[~df_to_check['PatientGuid'].isin(patient_df_cleaned['PatientGuid'])].empty:
        print(f"All PatientGuids in {df_name} exist in patient_df_cleaned.")
    else:
        missing_guids_count = df_to_check[~df_to_check['PatientGuid'].isin(patient_df_cleaned['PatientGuid'])]['PatientGuid'].nunique()
        total_guids_to_check = df_to_check['PatientGuid'].nunique()
        proportion_missing = missing_guids_count / total_guids_to_check * 100
        print(f"{proportion_missing:.2f}% of PatientGuids in {df_name} are not found in patient_df_cleaned.")
# Helper to report whether every value of a key column in a child table also
# exists in its parent table (prints the same messages as the individual checks)
def report_missing_keys(child_df, child_name, parent_df, parent_name, key, label):
    orphaned = child_df[~child_df[key].isin(parent_df[key])]
    if orphaned.empty:
        print(f"All {label} in {child_name} exist in {parent_name}.")
    else:
        proportion_missing = orphaned[key].nunique() / child_df[key].nunique() * 100
        print(f"{proportion_missing:.2f}% of {label} in {child_name} are not found in {parent_name}.")
# Check if all LabResultGuids in labobservation_df_cleaned_v2 and pathology_df
# exist in labresult_df
report_missing_keys(labobservation_df_cleaned_v2, 'labobservation_df_cleaned_v2', labresult_df, 'labresult_df', 'LabResultGuid', 'LabResultGuids')
report_missing_keys(pathology_df, 'pathology_df', labresult_df, 'labresult_df', 'LabResultGuid', 'LabResultGuids')
# Check if all ICD-9 codes in diagnosis_df_cleaned exist in icd9_df_cleaned_v2
report_missing_keys(diagnosis_df_cleaned, 'diagnosis_df_cleaned', icd9_df_cleaned_v2, 'icd9_df_cleaned_v2', 'ICD9Code_cleaned', 'ICD-9 codes')
# Check if all physician specialties in visit_df_cleaned exist in
# specialty_df_cleaned
report_missing_keys(visit_df_cleaned, 'visit_df_cleaned', specialty_df_cleaned, 'specialty_df_cleaned', 'PhysicianSpecialty', 'physician specialties')
# Check if all state codes in patient_df_cleaned exist in
# statedetails_df
report_missing_keys(patient_df_cleaned, 'patient_df_cleaned', statedetails_df, 'statedetails_df', 'StateCode', 'state codes')
All PatientGuids in diagnosis_df_cleaned exist in patient_df_cleaned.
All PatientGuids in visit_df_cleaned exist in patient_df_cleaned.
All PatientGuids in labresult_df exist in patient_df_cleaned.
All PatientGuids in smoking_df_cleaned exist in patient_df_cleaned.
All PatientGuids in prescription_df_cleaned_v2 exist in patient_df_cleaned.
All LabResultGuids in labobservation_df_cleaned_v2 exist in labresult_df.
All LabResultGuids in pathology_df exist in labresult_df.
All ICD-9 codes in diagnosis_df_cleaned exist in icd9_df_cleaned_v2.
All physician specialties in visit_df_cleaned exist in specialty_df_cleaned.
All state codes in patient_df_cleaned exist in statedetails_df.
Q2a Step 15 - Splitting of Data into Training, Validation, and Testing Sets¶
# Q2a Step 15
# Get unique PatientGuids from patient_df_cleaned
patient_keys = patient_df_cleaned.drop_duplicates(subset = ['PatientGuid'])[['PatientGuid']].copy()
# Transform each PatientGuid into a number between 0 and 1 following an
# approximate uniform distribution in a replicable way (i.e. pseudo-random)
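import hashlib  # standard library; harmless if already imported earlier in the notebook
# Named hash_fraction rather than hex to avoid shadowing the built-in hex()
hash_fraction = patient_keys['PatientGuid'].apply(lambda x: int(hashlib.md5(x.encode()).hexdigest(), 16) % 1000) / 1000.0
# Split the patients into approximately 70% for training, 15% for validation, and 15% for testing
bins = pd.cut(hash_fraction, bins = [-0.001, 0.70, 0.85, 1], labels = ['train', 'val', 'test'])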
# Apply the split
patient_split = pd.DataFrame({'PatientGuid': patient_keys['PatientGuid'], "set": bins.astype(str)})
# Verify the proportions for training, validation, and testing
display(patient_split['set'].value_counts(normalize = True).reindex(['train', 'val', 'test']))
# Function to apply splitting to other tables with 'PatientGuid' as a key
def merge_split(df: pd.DataFrame) -> pd.DataFrame:
    if 'PatientGuid' in df.columns:
        return df.merge(patient_split, on = 'PatientGuid', how = 'left')
    return df
# List of tables to apply merge_split to
dataframes_to_split = {
    'patient_df_cleaned': patient_df_cleaned,
    'diagnosis_df_cleaned': diagnosis_df_cleaned,
    'visit_df_cleaned': visit_df_cleaned,
    'labresult_df': labresult_df,
    'smoking_df_cleaned': smoking_df_cleaned,
    'prescription_df_cleaned_v2': prescription_df_cleaned_v2
}
# Apply merge_split and store in new DataFrames with "_split" suffix
split_dataframes = {}
for name, df in dataframes_to_split.items():
    new_name = name + "_split"
    split_dataframes[new_name] = df.pipe(merge_split)
# Assign the new DataFrames to variables in the global scope and check the
# proportions of train, val, and test in each new DataFrame
proportions_dict = {}
for name, df in split_dataframes.items():
    globals()[name] = df
    proportions_dict[name] = df['set'].value_counts(normalize = True).reindex(['train', 'val', 'test'])
print("\nProportions of train, val, and test across split DataFrames:")
display(pd.DataFrame(proportions_dict))
| set | proportion |
|---|---|
| train | 0.709750 |
| val | 0.136947 |
| test | 0.153303 |
Proportions of train, val, and test across split DataFrames:
| set | patient_df_cleaned_split | diagnosis_df_cleaned_split | visit_df_cleaned_split | labresult_df_split | smoking_df_cleaned_split | prescription_df_cleaned_v2_split |
|---|---|---|---|---|---|---|
| train | 0.713183 | 0.709839 | 0.712454 | 0.709170 | 0.715925 | 0.712572 |
| val | 0.134308 | 0.139702 | 0.132562 | 0.133442 | 0.135342 | 0.136469 |
| test | 0.152510 | 0.150460 | 0.154984 | 0.157388 | 0.148733 | 0.150959 |
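The split in Step 15 is computed from each PatientGuid alone, so a patient's assignment does not depend on row order, a random seed, or which other patients are present, and all of a patient's rows land in the same set. The sketch below restates that rule as a plain function and checks it on synthetic identifiers. The function name `assign_set` and the fake GUIDs are assumptions for illustration. Because `pd.cut` bins are right-closed, a bucket value of exactly 700 (0.700) falls in 'train'.

```python
import hashlib

def assign_set(patient_guid: str) -> str:
    """Map a GUID to 'train', 'val' or 'test' using the same cut points as above."""
    bucket = int(hashlib.md5(patient_guid.encode()).hexdigest(), 16) % 1000
    if bucket <= 700:   # (-0.001, 0.70] -> train
        return 'train'
    if bucket <= 850:   # (0.70, 0.85] -> val
        return 'val'
    return 'test'       # (0.85, 1] -> test

# Synthetic identifiers, used purely for illustration
fake_guids = [f"patient-{i}" for i in range(10_000)]
assignments = [assign_set(guid) for guid in fake_guids]
shares = {name: assignments.count(name) / len(assignments) for name in ('train', 'val', 'test')}
print(shares)
```

Re-running the function on the same identifier always returns the same set, which is what makes the split replicable across notebook sessions.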
Q2a Step 16 - Imputation of Missing Values in 'Visit' Table¶
Height¶
# Q2a Step 16
visit_df_cleaned_split_v2 = visit_df_cleaned_split.copy()
# Check missingness of height data by set
print("Percentage of missing height data by set:")
display(visit_df_cleaned_split_v2.groupby('set')['Height_cleaned'].apply(lambda x: x.isnull().sum() / len(x) * 100).rename('Missing Height (in %)').reindex(['train', 'val', 'test']))
# Add indicator column for rows with missing height data
visit_df_cleaned_split_v2['Height_missing'] = visit_df_cleaned_split_v2['Height_cleaned'].isnull().astype(int)
# Verify that gender does not change for each patient
print(f"\nPatients with more than one gender: {len(patient_df_cleaned.groupby('PatientGuid')['Gender'].nunique()[patient_df_cleaned.groupby('PatientGuid')['Gender'].nunique() > 1])}")
# Retrieve gender from patient_df_cleaned
visit_df_cleaned_split_v2 = visit_df_cleaned_split_v2.merge(patient_df_cleaned[['PatientGuid', 'Gender']].drop_duplicates(), on = 'PatientGuid', how = 'left')
# Compare rows before and after to ensure 'visit' table structure has not
# changed
print("\nNumber of rows before and after retrieving gender information:")
display(visit_df_cleaned_split_v2.shape)
display(visit_df_cleaned_split.shape)
# Calculate median height by gender for the training set
median_height_by_gender = visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['set'] == 'train'].groupby('Gender')['Height_cleaned'].median()
# Impute missing height values using the median by gender from the training set
visit_df_cleaned_split_v2['Height_imputed'] = visit_df_cleaned_split_v2.apply(lambda row: median_height_by_gender.get(row['Gender'], row['Height_cleaned']) if pd.isnull(row['Height_cleaned']) else row['Height_cleaned'], axis = 1)
# Verify that there are no more missing height data
print("\nMissing height data after imputation:")
print(visit_df_cleaned_split_v2[['Height_imputed']].isnull().sum())
# Verify that the indicator column was populated correctly
# Rows with Height_missing == 1 should not have value in Height_cleaned
print(f"\nNumber of rows with Height_missing == 1 and non-missing Height_cleaned: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['Height_missing'] == 1) & (visit_df_cleaned_split_v2['Height_cleaned'].notnull())].shape[0]}")
# Rows with Height_missing == 0 should have value in Height_cleaned
print(f"\nNumber of rows with Height_missing == 0 and missing Height_cleaned: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['Height_missing'] == 0) & (visit_df_cleaned_split_v2['Height_cleaned'].isnull())].shape[0]}")
# Verify that height did not change for all rows without missing height
print(f"\nNumber of rows without missing height but different height value before and after cleaning: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['Height_missing'] == 0) & (visit_df_cleaned_split_v2['Height_imputed'] != visit_df_cleaned_split_v2['Height_cleaned'])].shape[0]}")
# Get unique imputed height by gender where Height_missing == 1 to verify that
# imputation was done correctly
print("\nMedian height by gender from training set:")
display(median_height_by_gender)
print("\nUnique imputed height by gender where Height_missing == 1:")
display(visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['Height_missing'] == 1].groupby('Gender')['Height_imputed'].unique())
Percentage of missing height data by set:
| set | Missing Height (in %) |
|---|---|
| train | 64.725424 |
| val | 65.493315 |
| test | 62.343490 |
Patients with more than one gender: 0
Number of rows before and after retrieving gender information:
(89151, 19)
(89151, 17)
Missing height data after imputation:
Height_imputed 0
dtype: int64
Number of rows with Height_missing == 1 and non-missing Height_cleaned: 0
Number of rows with Height_missing == 0 and missing Height_cleaned: 0
Number of rows without missing height but different height value before and after cleaning: 0
Median height by gender from training set:
| Gender | Height_cleaned |
|---|---|
| F | 65.0 |
| M | 67.5 |
Unique imputed height by gender where Height_missing == 1:
| Gender | Height_imputed |
|---|---|
| F | [65.0] |
| M | [67.5] |
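The imputation above applies a Python lambda to every row. A vectorised sketch of the same idea, assuming medians are taken from the training rows only and looked up by gender, uses `map` and `fillna`. The function name `impute_with_train_median` and the generated column names (`<column>_missing`, `<column>_imputed`) are hypothetical and differ slightly from the names used in the notebook.

```python
import numpy as np
import pandas as pd

def impute_with_train_median(df, value_col, group_col = 'Gender', set_col = 'set'):
    """Add a missing-value indicator and a column imputed with training-set group medians."""
    out = df.copy()
    # Medians come from training rows only, so validation and test data do not leak in
    train_medians = out.loc[out[set_col] == 'train'].groupby(group_col)[value_col].median()
    out[f'{value_col}_missing'] = out[value_col].isna().astype(int)
    # Look up each row's group median and use it only where the value is missing
    out[f'{value_col}_imputed'] = out[value_col].fillna(out[group_col].map(train_medians))
    return out, train_medians

# Toy data for illustration only
toy = pd.DataFrame({
    'Gender': ['F', 'F', 'F', 'M', 'M', 'M'],
    'set': ['train', 'train', 'val', 'train', 'train', 'test'],
    'Height_cleaned': [64.0, 66.0, np.nan, 67.0, 68.0, np.nan],
})
imputed, medians = impute_with_train_median(toy, 'Height_cleaned')
print(imputed)
```

The same function could be reused for weight and blood pressure by changing `value_col`, which avoids repeating the imputation code for each measurement.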
Weight¶
# Q2a Step 16
# Check missingness of weight data by set
print("Percentage of missing weight data by set:")
display(visit_df_cleaned_split_v2.groupby('set')['Weight'].apply(lambda x: x.isnull().sum() / len(x) * 100).rename('Missing Weight (in %)').reindex(['train', 'val', 'test']))
# Add indicator column for rows with missing weight data
visit_df_cleaned_split_v2['Weight_missing'] = visit_df_cleaned_split_v2['Weight'].isnull().astype(int)
# Calculate median weight by gender for the training set
median_weight_by_gender = visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['set'] == 'train'].groupby('Gender')['Weight'].median()
# Impute missing weight values using the median by gender from the training set
visit_df_cleaned_split_v2['Weight_imputed'] = visit_df_cleaned_split_v2.apply(lambda row: median_weight_by_gender.get(row['Gender'], row['Weight']) if pd.isnull(row['Weight']) else row['Weight'], axis = 1)
# Verify that there are no more missing weight data
print("\nMissing weight data after imputation:")
print(visit_df_cleaned_split_v2[['Weight_imputed']].isnull().sum())
# Verify that the indicator column was populated correctly
# Rows with Weight_missing == 1 should not have value in Weight
print(f"\nNumber of rows with Weight_missing == 1 and non-missing Weight: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['Weight_missing'] == 1) & (visit_df_cleaned_split_v2['Weight'].notnull())].shape[0]}")
# Rows with Weight_missing == 0 should have value in Weight
print(f"\nNumber of rows with Weight_missing == 0 and missing Weight: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['Weight_missing'] == 0) & (visit_df_cleaned_split_v2['Weight'].isnull())].shape[0]}")
# Verify that weight did not change for all rows without missing weight
print(f"\nNumber of rows without missing weight but different weight value before and after cleaning: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['Weight_missing'] == 0) & (visit_df_cleaned_split_v2['Weight_imputed'] != visit_df_cleaned_split_v2['Weight'])].shape[0]}")
# Get unique imputed weight by gender where Weight_missing == 1 to verify that
# imputation was done correctly
print("\nMedian weight by gender from training set:")
display(median_weight_by_gender)
print("\nUnique imputed weight by gender where Weight_missing == 1:")
display(visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['Weight_missing'] == 1].groupby('Gender')['Weight_imputed'].unique())
Percentage of missing weight data by set:
| set | Missing Weight (in %) |
|---|---|
| train | 52.682789 |
| val | 53.545439 |
| test | 52.826229 |
Missing weight data after imputation:
Weight_imputed 0
dtype: int64
Number of rows with Weight_missing == 1 and non-missing Weight: 0
Number of rows with Weight_missing == 0 and missing Weight: 0
Number of rows without missing weight but different weight value before and after cleaning: 0
Median weight by gender from training set:
| Gender | Weight |
|---|---|
| F | 175.00 |
| M | 191.25 |
Unique imputed weight by gender where Weight_missing == 1:
| Gender | Weight_imputed |
|---|---|
| F | [175.0] |
| M | [191.25] |
Systolic Blood Pressure¶
# Q2a Step 16
# Check missingness of systolic BP data by set
print("Percentage of missing systolic BP data by set:")
display(visit_df_cleaned_split_v2.groupby('set')['SystolicBP_cleaned'].apply(lambda x: x.isnull().sum() / len(x) * 100).rename('Missing Systolic BP (in %)').reindex(['train', 'val', 'test']))
# Add indicator column for rows with missing systolic BP data
visit_df_cleaned_split_v2['SystolicBP_missing'] = visit_df_cleaned_split_v2['SystolicBP_cleaned'].isnull().astype(int)
# Calculate median systolic BP by gender for the training set
median_systolicBP_by_gender = visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['set'] == 'train'].groupby('Gender')['SystolicBP_cleaned'].median()
# Impute missing systolic BP values using the median by gender from the training
# set
visit_df_cleaned_split_v2['SystolicBP_imputed'] = visit_df_cleaned_split_v2.apply(lambda row: median_systolicBP_by_gender.get(row['Gender'], row['SystolicBP_cleaned']) if pd.isnull(row['SystolicBP_cleaned']) else row['SystolicBP_cleaned'], axis = 1)
# Verify that there are no more missing systolic BP data
print("\nMissing systolic BP data after imputation:")
print(visit_df_cleaned_split_v2[['SystolicBP_imputed']].isnull().sum())
# Verify that the indicator column was populated correctly
# Rows with SystolicBP_missing == 1 should not have value in SystolicBP_cleaned
print(f"\nNumber of rows with SystolicBP_missing == 1 and non-missing SystolicBP_cleaned: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['SystolicBP_missing'] == 1) & (visit_df_cleaned_split_v2['SystolicBP_cleaned'].notnull())].shape[0]}")
# Rows with SystolicBP_missing == 0 should have value in SystolicBP_cleaned
print(f"\nNumber of rows with SystolicBP_missing == 0 and missing SystolicBP_cleaned: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['SystolicBP_missing'] == 0) & (visit_df_cleaned_split_v2['SystolicBP_cleaned'].isnull())].shape[0]}")
# Verify that SystolicBP did not change for all rows without missing SystolicBP
print(f"\nNumber of rows without missing SystolicBP but different SystolicBP value before and after cleaning: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['SystolicBP_missing'] == 0) & (visit_df_cleaned_split_v2['SystolicBP_imputed'] != visit_df_cleaned_split_v2['SystolicBP_cleaned'])].shape[0]}")
# Get unique imputed SystolicBP by gender where SystolicBP_missing == 1 to
# verify that imputation was done correctly
print("\nMedian Systolic BP by gender from training set:")
display(median_systolicBP_by_gender)
print("\nUnique imputed SystolicBP by gender where SystolicBP_missing == 1:")
display(visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['SystolicBP_missing'] == 1].groupby('Gender')['SystolicBP_imputed'].unique())
Percentage of missing systolic BP data by set:
| set | Missing Systolic BP (in %) |
|---|---|
| train | 45.120914 |
| val | 48.290743 |
| test | 46.051965 |
Missing systolic BP data after imputation:
SystolicBP_imputed    0
dtype: int64
Number of rows with SystolicBP_missing == 1 and non-missing SystolicBP_cleaned: 0
Number of rows with SystolicBP_missing == 0 and missing SystolicBP_cleaned: 0
Number of rows without missing SystolicBP but different SystolicBP value before and after cleaning: 0
Median Systolic BP by gender from training set:
| Gender | SystolicBP_cleaned |
|---|---|
| F | 124.0 |
| M | 124.0 |
Unique imputed SystolicBP by gender where SystolicBP_missing == 1:
| Gender | SystolicBP_imputed |
|---|---|
| F | [124.0] |
| M | [124.0] |
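The row-wise `apply` used above works, but the same gender-specific median imputation can also be written with `Series.map` and `fillna`, which is usually faster on large frames. The following is a minimal sketch on a toy frame: the `toy` DataFrame and its values are invented for illustration and are not Betahelf data, while the column names mirror those used above.

```python
import numpy as np
import pandas as pd

# Toy stand-in for visit_df_cleaned_split_v2 (illustrative values only)
toy = pd.DataFrame({
    'set': ['train', 'train', 'train', 'train', 'val', 'test'],
    'Gender': ['F', 'F', 'M', 'M', 'F', 'M'],
    'SystolicBP_cleaned': [120.0, 130.0, 110.0, np.nan, np.nan, np.nan],
})

# Learn the medians from the training rows only, so that no information
# from the validation or testing sets leaks into the imputation
train_median = (
    toy.loc[toy['set'] == 'train']
    .groupby('Gender')['SystolicBP_cleaned']
    .median()
)

# Flag missingness before filling so that the signal is preserved
toy['SystolicBP_missing'] = toy['SystolicBP_cleaned'].isna().astype(int)

# Map each row's gender to its training median and use it only where the
# original value is missing
toy['SystolicBP_imputed'] = toy['SystolicBP_cleaned'].fillna(
    toy['Gender'].map(train_median)
)

print(toy)
```

As with the `apply` version, a gender that never appears in the training set would map to NaN and stay missing, so the post-imputation null check remains worthwhile.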
Diastolic Blood Pressure¶
# Q2a Step 16
# Check missingness of diastolic BP data by set
print("Percentage of missing diastolic BP data by set:")
display(visit_df_cleaned_split_v2.groupby('set')['DiastolicBP_cleaned'].apply(lambda x: x.isnull().sum() / len(x) * 100).rename('Missing Diastolic BP (in %)').reindex(['train', 'val', 'test']))
# Add indicator column for rows with missing diastolic BP data
visit_df_cleaned_split_v2['DiastolicBP_missing'] = visit_df_cleaned_split_v2['DiastolicBP_cleaned'].isnull().astype(int)
# Calculate median diastolic BP by gender for the training set
median_diastolicBP_by_gender = visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['set'] == 'train'].groupby('Gender')['DiastolicBP_cleaned'].median()
# Impute missing diastolic BP values using the median by gender from the
# training set
visit_df_cleaned_split_v2['DiastolicBP_imputed'] = visit_df_cleaned_split_v2.apply(lambda row: median_diastolicBP_by_gender.get(row['Gender'], row['DiastolicBP_cleaned']) if pd.isnull(row['DiastolicBP_cleaned']) else row['DiastolicBP_cleaned'], axis = 1)
# Verify that there are no more missing diastolic BP data
print("\nMissing diastolic BP data after imputation:")
print(visit_df_cleaned_split_v2[['DiastolicBP_imputed']].isnull().sum())
# Verify that the indicator column was populated correctly
# Rows with DiastolicBP_missing == 1 should not have value in
# DiastolicBP_cleaned
print(f"\nNumber of rows with DiastolicBP_missing == 1 and non-missing DiastolicBP_cleaned: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['DiastolicBP_missing'] == 1) & (visit_df_cleaned_split_v2['DiastolicBP_cleaned'].notnull())].shape[0]}")
# Rows with DiastolicBP_missing == 0 should have value in DiastolicBP_cleaned
print(f"\nNumber of rows with DiastolicBP_missing == 0 and missing DiastolicBP_cleaned: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['DiastolicBP_missing'] == 0) & (visit_df_cleaned_split_v2['DiastolicBP_cleaned'].isnull())].shape[0]}")
# Verify that DiastolicBP did not change for all rows without missing
# DiastolicBP
print(f"\nNumber of rows without missing DiastolicBP but different DiastolicBP value before and after cleaning: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['DiastolicBP_missing'] == 0) & (visit_df_cleaned_split_v2['DiastolicBP_imputed'] != visit_df_cleaned_split_v2['DiastolicBP_cleaned'])].shape[0]}")
# Get unique imputed DiastolicBP by gender where DiastolicBP_missing == 1 to
# verify that imputation was done correctly
print("\nMedian diastolic BP by gender from training set:")
display(median_diastolicBP_by_gender)
print("\nUnique imputed DiastolicBP by gender where DiastolicBP_missing == 1:")
display(visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['DiastolicBP_missing'] == 1].groupby('Gender')['DiastolicBP_imputed'].unique())
Percentage of missing diastolic BP data by set:
| set | Missing Diastolic BP (in %) |
|---|---|
| train | 44.968197 |
| val | 48.223050 |
| test | 45.726279 |
Missing diastolic BP data after imputation:
DiastolicBP_imputed    0
dtype: int64
Number of rows with DiastolicBP_missing == 1 and non-missing DiastolicBP_cleaned: 0
Number of rows with DiastolicBP_missing == 0 and missing DiastolicBP_cleaned: 0
Number of rows without missing DiastolicBP but different DiastolicBP value before and after cleaning: 0
Median diastolic BP by gender from training set:
| Gender | DiastolicBP_cleaned |
|---|---|
| F | 78.0 |
| M | 78.0 |
Unique imputed DiastolicBP by gender where DiastolicBP_missing == 1:
| Gender | DiastolicBP_imputed |
|---|---|
| F | [78.0] |
| M | [78.0] |
Q2b - Propose unit of analysis¶
Entity
The entity chosen for the unit of analysis will be each patient as referenced by the 'PatientGuid' in the data.
Why?
The entity is a patient and not, say, a visit or diagnosis, as this suits Betahelf's business context better and is also a more operationally feasible choice as expounded in further detail below:
Betahelf's management team wants to develop a healthcare program that proactively targets patients who are likely to receive an acute diagnosis over the next 12 months. As such, Betahelf is likely to call up or notify patients who are deemed to be at risk of receiving an acute diagnosis in the next year; they would not call up a "visit" or a "diagnosis". Thus, it is more appropriate that each row in the unit of analysis represents one patient at one point in time.
A visit- or diagnosis-indexed panel would scatter a patient across many rows in the same month, yet in practice, Betahelf's healthcare team would still likely only make one contact decision per patient who is deemed at risk of receiving an acute diagnosis over the next 12 months.
Choosing the patient as the entity avoids oversampling high utilisers. For example, with a visit-indexed panel, one patient with 10 visits in a month would appear 10x and dominate the training set while a quiet but high-risk member may appear only once. This distorts model learning and calibration.
Choosing the patient as the entity avoids duplicating "positives". In particular, a diagnosis-indexed panel would multiply label rows near an event of interest (i.e. same patient, same outcome), inflating apparent model performance and complicating evaluation.
In healthcare, features are naturally patient-centric as opposed to visit- or diagnosis-centric. For example, we would not expect a patient's height to vary by visit or diagnosis but, rather, across patients.
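The oversampling argument above can be illustrated with a small sketch. The `visits` frame below is hypothetical (one patient with 10 visits in a month and one with a single visit) and only serves to compare each patient's share of rows under a visit-indexed panel against a patient-month panel.

```python
import pandas as pd

# Hypothetical visits: patient A has 10 visits in March 2021, patient B has 1
visits = pd.DataFrame({
    'PatientGuid': ['A'] * 10 + ['B'],
    'Timestamp': pd.to_datetime(
        [f'2021-03-{day:02d}' for day in range(1, 11)] + ['2021-03-15']
    ),
})

# Visit-indexed panel: one row per visit, so A contributes 10 of 11 rows
visit_share = visits['PatientGuid'].value_counts(normalize=True)

# Patient-month panel: one row per patient per calendar month
patient_month = (
    visits.assign(Month=visits['Timestamp'].dt.to_period('M'))
    .drop_duplicates(['PatientGuid', 'Month'])
)
patient_share = patient_month['PatientGuid'].value_counts(normalize=True)

print("Share of rows in a visit-indexed panel:")
print(visit_share)
print("\nShare of rows in a patient-month panel:")
print(patient_share)
```

In this toy case patient A accounts for roughly 91% of the visit-indexed rows but exactly half of the patient-month rows, which is the imbalance the patient-level entity is intended to avoid.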
Timestamps
The basis of the timestamps chosen for the unit of analysis will be regularly-spaced timestamps. In particular, our as-at-date will be the end of each calendar month.
Why?
Similar to why the entity is a patient and not, say, a visit or diagnosis, regularly-spaced timestamps suit Betahelf's business context better and are also a more operationally feasible choice, as expounded in further detail below:
Having regularly-spaced timestamps is much more realistic for Betahelf to administer in practice as it is not realistic to run the model each time a patient has an event of interest (e.g. a visit or lab result). It would be much more pragmatic for Betahelf to run the model once a month, at the end of each month, for all its current patients based on each patient's updated medical history, and contact the appropriate patients (i.e. those at risk of receiving an acute diagnosis over the next 12 months) where necessary.
Regularly-spaced timestamps align with the argument above that Betahelf should really be targeting patients as opposed to the events of interest themselves. As an extension to this argument, if the timestamps were event-driven and a patient had a few events of interest in a given month, where each event is predicting an acute diagnosis, Betahelf may end up notifying or targeting the same patient multiple times within that month. This adds noise to the model and risks duplicating positives around one outcome, which may be an inefficient use of Betahelf's resources.
Regularly-spaced timestamps allow for fair sampling and stable calibration. As alluded to above, event-driven timestamps oversample high utilisers (i.e. many events of interest in some months) and drop quiet months (i.e. no events of interest) entirely. A model trained on such a distribution will not be well calibrated on the full population of patients for whom Betahelf wants predictions. In particular, regularly-spaced timestamps keep the denominator consistent; "no events of interest in a given month" becomes a learnable signal instead of disappearing entirely from the model.
We use the end of each calendar month, as opposed to the start, for our as-at-dates simply because it is generally industry standard to run such models at the end of the month. It would be quite strange, in practice, to define the as-at-dates at the start of each month, which may also confuse users of the model and other stakeholders, such as Betahelf's management team, who are more likely to be used to seeing metrics defined at the end of a time period (e.g. end of each month).
The earliest as-at-date for any given patient is 12 months after that patient's earliest event of interest. This definition is consistent with our feature engineering later, where we will decide to use a patient's one-year medical history in the model as we believe such history carries significant information in predicting acute diagnosis over the next 12 months. For example, if we included rows for a patient from January 2020 while they only entered Betahelf's care in June 2020, the earlier rows for this patient will not have the required history. Using later data to backfill the missing history would be a form of data leakage.
Similarly, the latest as-at-date for any given patient is 12 months before that patient's latest event of interest. This ensures that there is no right-censoring of the response variable, which will look forward for the next 12 months. For example, if we included rows for a patient until December 2024 while their last observed data was, say, March 2023, we are not able to see what happens to this patient between April 2023 and December 2024. This effectively makes the response variable for this patient unknown, where marking it as 0 (i.e. no acute diagnosis over the next 12 months) undercounts positives and dropping the row introduces bias.
To summarise the two preceding points, the restrictions imposed on the earliest and latest as-at-dates for each patient ensure that each row in the unit of analysis has a full history for features and a fully observed future for the response variable.
Notwithstanding the above explanations, we acknowledge that specific events of interest may carry useful information as predictors to whether a patient is likely to receive an acute diagnosis over the next 12 months. As such, in our feature engineering later, we will add an event-driven overlay to each row in the unit of analysis. This overlay will be a flag to indicate whether a patient has had an abnormal lab result in that given month.
Conclusion
In summary, for the unit of analysis, we propose to use patients, referenced by 'PatientGuid', as the entity and regularly-spaced timestamps with as-at-dates defined to be the end of each calendar month and relevant observation periods for each patient. This is an appropriate unit of analysis for Betahelf as it provides Betahelf with a timely, monthly, prioritised patient list that is operationally actionable, convenient, and auditable. In this way, Betahelf's management/healthcare team can maximise the use of their limited resources in an efficient manner.
# Q2b create a table of entities and timestamps
LOOKBACK_DAYS = 365
LABEL_WINDOW_DAYS = 365
# Function to gather patient-level time bounds from events of interest
def patient_time_bounds(*event_dfs):
    rows = []
    for df in event_dfs:
        rows.append(df.groupby('PatientGuid')['Timestamp'].agg(['min', 'max']))
    return pd.concat(rows).groupby(level = 0).agg({'min': 'min', 'max': 'max'}).reset_index()
# Gather bounds from diagnoses, visits, lab results, and prescriptions
bounds = patient_time_bounds(diagnosis_df_cleaned, visit_df_cleaned, labresult_df, prescription_df_cleaned_v2)
# Calculate snapshot window for each bound
# Snapshot start is earliest event date + 12 months as we are using a lookback
# period of 12 months
# Snapshot end is latest event date - 12 months as our response variable will
# look forward for 12 months
bounds['snapshot_start'] = (bounds['min'] + pd.Timedelta(days = LOOKBACK_DAYS)).dt.to_period('M').dt.end_time
bounds['snapshot_end'] = (bounds['max'] - pd.Timedelta(days = LABEL_WINDOW_DAYS)).dt.to_period('M').dt.end_time
snapshots = []
for _, r in bounds.iterrows():
    snapshot = pd.date_range(r['snapshot_start'], r['snapshot_end'], freq = 'M') if r['snapshot_start'] < r['snapshot_end'] else []
    if len(snapshot):
        snapshots.append(pd.DataFrame({'PatientGuid': r['PatientGuid'], 'AsAtDate': snapshot}))
# Create unit_of_analysis_df and attach the splits obtained in the previous
# section
unit_of_analysis_df = pd.concat(snapshots, ignore_index = True).merge(patient_split, on = 'PatientGuid', how = 'left')
# Check shape of unit_of_analysis_df
print("Shape of unit_of_analysis_df:")
display(unit_of_analysis_df.shape)
# Verify that no patients have more than one unique value in the 'set' column
print(f"\nPatientGuids with more than one unique value in the 'set' column: {len(unit_of_analysis_df.groupby('PatientGuid')['set'].agg(n_unique_sets = lambda x: x.nunique(dropna = True)).query('n_unique_sets > 1'))}")
# Display a random sample of 5 rows from each of the training, validation, and
# testing sets to verify the structure of unit_of_analysis_df
print("\nSample of 5 rows from each set to verify structure of unit_of_analysis_df:")
display(unit_of_analysis_df.groupby('set').apply(lambda x: x.sample(5, random_state = random_seed)).reset_index(level = 'set', drop = True))
print("=" * 150)
# Verify that each patient has a row regardless of whether the patient has any
# event of interest in that month
print("\nVerify that patients have a row in unit_of_analysis_df regardless of whether they have any event of interest in that month:")
display(unit_of_analysis_df[(unit_of_analysis_df['PatientGuid'] == '7b1a82fe-4d2b-4164-b13d-2f70211c379f') & (unit_of_analysis_df['AsAtDate'].dt.year == 2021) & (unit_of_analysis_df['AsAtDate'].dt.month == 12)])
display(diagnosis_df_cleaned[(diagnosis_df_cleaned['PatientGuid'] == '7b1a82fe-4d2b-4164-b13d-2f70211c379f') & (diagnosis_df_cleaned['Timestamp'].dt.year == 2021) & (diagnosis_df_cleaned['Timestamp'].dt.month == 12)])
display(visit_df_cleaned[(visit_df_cleaned['PatientGuid'] == '7b1a82fe-4d2b-4164-b13d-2f70211c379f') & (visit_df_cleaned['Timestamp'].dt.year == 2021) & (visit_df_cleaned['Timestamp'].dt.month == 12)])
display(labresult_df[(labresult_df['PatientGuid'] == '7b1a82fe-4d2b-4164-b13d-2f70211c379f') & (labresult_df['Timestamp'].dt.year == 2021) & (labresult_df['Timestamp'].dt.month == 12)])
display(prescription_df_cleaned_v2[(prescription_df_cleaned_v2['PatientGuid'] == '7b1a82fe-4d2b-4164-b13d-2f70211c379f') & (prescription_df_cleaned_v2['Timestamp'].dt.year == 2021) & (prescription_df_cleaned_v2['Timestamp'].dt.month == 12)])
print("=" * 150)
# Verify that the appropriate rows are created for each patient (i.e. no rows
# outside of earliest event + 12 months and latest event - 12 months and monthly
# rows in between)
sample_patient = '7b1a82fe-4d2b-4164-b13d-2f70211c379f'
print("\nVerify that the appropriate rows are created for each patient")
print("(i.e. no rows outside of earliest event + 12 months and latest event - 12 months and monthly rows in between, only appear in one set):")
print(f"Earliest AsAtDate: {unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == sample_patient]['AsAtDate'].min()}")
print(f"Latest AsAtDate: {unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == sample_patient]['AsAtDate'].max()}")
print(f"Number of rows: {len(unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == sample_patient])}")
print(f"Minimum Timestamp: {pd.concat([diagnosis_df_cleaned[diagnosis_df_cleaned['PatientGuid'] == sample_patient], visit_df_cleaned[visit_df_cleaned['PatientGuid'] == sample_patient], labresult_df[labresult_df['PatientGuid'] == sample_patient], prescription_df_cleaned_v2[prescription_df_cleaned_v2['PatientGuid'] == sample_patient]
])['Timestamp'].min()}")
print(f"Maximum Timestamp: {pd.concat([diagnosis_df_cleaned[diagnosis_df_cleaned['PatientGuid'] == sample_patient], visit_df_cleaned[visit_df_cleaned['PatientGuid'] == sample_patient], labresult_df[labresult_df['PatientGuid'] == sample_patient], prescription_df_cleaned_v2[prescription_df_cleaned_v2['PatientGuid'] == sample_patient]
])['Timestamp'].max()}")
print(f"Unique values in the 'set' column: {unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == sample_patient]['set'].unique()}")
Shape of unit_of_analysis_df:
(98704, 3)
PatientGuids with more than one unique value in the 'set' column: 0
Sample of 5 rows from each set to verify structure of unit_of_analysis_df:
| | PatientGuid | AsAtDate | set |
|---|---|---|---|
| 97175 | fb6243fe-dc5c-4b48-978a-c8e640962519 | 2023-07-31 23:59:59.999999999 | test |
| 87296 | e3d67627-846c-4171-9b49-bab443263706 | 2023-07-31 23:59:59.999999999 | test |
| 68054 | b346bd6f-9f4a-443a-afd5-fadcb7133324 | 2022-12-31 23:59:59.999999999 | test |
| 87248 | e3c52817-5926-4f14-8ccf-772642e4dd21 | 2021-12-31 23:59:59.999999999 | test |
| 31482 | 51f21281-02e5-4fd9-92bc-a7f855debb4b | 2021-10-31 23:59:59.999999999 | test |
| 64753 | aa7e6edc-79ec-40d6-b5f9-d7383d16d6d7 | 2023-09-30 23:59:59.999999999 | train |
| 39116 | 67aef923-24e6-49db-8841-76d872503a3a | 2023-09-30 23:59:59.999999999 | train |
| 29837 | 4d960ae5-0f30-45e0-b430-5f76305372cc | 2022-04-30 23:59:59.999999999 | train |
| 98611 | ff9316c2-6ac6-421e-9a1a-38d94517094b | 2022-03-31 23:59:59.999999999 | train |
| 18581 | 31d54648-7044-4670-8fc5-730832f2447d | 2023-06-30 23:59:59.999999999 | train |
| 10750 | 1da9f3af-7237-42a2-9071-67fc447e77a2 | 2022-03-31 23:59:59.999999999 | val |
| 25925 | 426f1614-7730-4a03-8b88-f5fc651c841c | 2021-02-28 23:59:59.999999999 | val |
| 65223 | abe5d18f-898a-4795-bd23-2b4ddcbde650 | 2022-01-31 23:59:59.999999999 | val |
| 72165 | bd9a05c0-a603-4e20-8131-a3a5dff38653 | 2023-10-31 23:59:59.999999999 | val |
| 9511 | 1af3acc1-94c7-48e9-82be-674d5fc1c41b | 2021-12-31 23:59:59.999999999 | val |
====================================================================================================================================================== Verify that patients have a row in unit_of_analysis_df regardless of whether they have any event of interest in that month:
| PatientGuid | AsAtDate | set | |
|---|---|---|---|
| 46205 | 7b1a82fe-4d2b-4164-b13d-2f70211c379f | 2021-12-31 23:59:59.999999999 | train |
| DiagnosisGuid | PatientGuid | Timestamp | tz_offset | ICD9Code | DiagnosisDescription | Acute | ICD9Code_cleaned |
|---|---|---|---|---|---|---|---|
| VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | Temperature | PhysicianSpecialty | Height_cm | Height_cleaned | SystolicBP_cleaned | DiastolicBP_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| LabResultGuid | PatientGuid | Timestamp | tz_offset |
|---|---|---|---|
| PrescriptionGuid | PatientGuid | Timestamp | tz_offset | Quantity | NumberOfRefills | RefillAsNeeded | GenericAllowed | NdcCode | MedicationName | MedicationStrength | Schedule |
|---|---|---|---|---|---|---|---|---|---|---|---|
======================================================================================================================================================
Verify that the appropriate rows are created for each patient
(i.e. no rows outside of earliest event + 12 months and latest event - 12 months and monthly rows in between, only appear in one set):
Earliest AsAtDate: 2020-12-31 23:59:59.999999999
Latest AsAtDate: 2023-12-31 23:59:59.999999999
Number of rows: 37
Minimum Timestamp: 2020-01-01 18:13:29
Maximum Timestamp: 2024-12-06 19:50:38
Unique values in the 'set' column: ['train']
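The spot check above covers a single sample patient. As a sketch of how the same window check could be run for every patient at once, the snippet below rebuilds the snapshot windows for two toy patients and counts violations. `toy_bounds` and `toy_uoa` are hypothetical stand-ins for `bounds` and `unit_of_analysis_df`; patient A reuses the minimum and maximum timestamps printed above, while patient B is invented.

```python
import pandas as pd

# Toy stand-in for `bounds` (patient B is invented for illustration)
toy_bounds = pd.DataFrame({
    'PatientGuid': ['A', 'B'],
    'min': pd.to_datetime(['2020-01-01 18:13:29', '2020-06-15 00:00:00']),
    'max': pd.to_datetime(['2024-12-06 19:50:38', '2023-03-10 00:00:00']),
})
year = pd.Timedelta(days=365)
toy_bounds['snapshot_start'] = (toy_bounds['min'] + year).dt.to_period('M').dt.end_time
toy_bounds['snapshot_end'] = (toy_bounds['max'] - year).dt.to_period('M').dt.end_time

# Toy stand-in for `unit_of_analysis_df`: one month-end row per patient-month
frames = []
for r in toy_bounds.itertuples(index=False):
    months = pd.period_range(
        r.snapshot_start.to_period('M'), r.snapshot_end.to_period('M'), freq='M'
    ).to_timestamp(how='end')
    frames.append(pd.DataFrame({'PatientGuid': r.PatientGuid, 'AsAtDate': months}))
toy_uoa = pd.concat(frames, ignore_index=True)

def month_number(ts):
    # Months since year 0, used to count calendar months between two dates
    return ts.dt.year * 12 + ts.dt.month

# Compare each patient's observed rows against the expected window
check = (
    toy_uoa.groupby('PatientGuid')['AsAtDate']
    .agg(first='min', last='max', n_rows='count')
    .reset_index()
    .merge(toy_bounds, on='PatientGuid')
)
check['expected_rows'] = (
    month_number(check['snapshot_end']) - month_number(check['snapshot_start']) + 1
)
violations = check[
    (check['first'] != check['snapshot_start'])
    | (check['last'] != check['snapshot_end'])
    | (check['n_rows'] != check['expected_rows'])
]
print(f"Patients violating the window rules: {len(violations)}")
```

Patient A reproduces the 37 monthly rows reported for the sample patient, which gives some reassurance that the toy construction matches the logic used for `unit_of_analysis_df`.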
Q2c - Construct response variable¶
Q2c Step 1¶
# Q2c Step 1
# Function to create response variable
def build_response(uoa: pd.DataFrame, dx: pd.DataFrame, label_window_days: int = 365) -> pd.DataFrame:
    """
    Construct AcuteDiagnosis for each row in unit_of_analysis_df.

    Definitions:
        p = patient
        S = as-at-date
        T = timestamp
        AcuteDiagnosis(p, S) = 1 iff there exists acute diagnosis for a particular ICD-9 code at time T in (S, S + 365] such that the same code was not in (S - 365, S].

    Implementation details:
    - "New" acute events per (PatientGuid, ICD9Code) defined as first occurrence OR >365d since previous same code.
    - For each new event at time T:
        start_ts = max(T - 365d, P + 365d) if there was a previous same-code at time P else T - 365d
        start_m = month end of start_ts
        end_m = previous month end of T (i.e. exclude event month)
    - Paint label = 1 for all as-at-dates in [start_m, end_m].
    """
    uoa = uoa.copy()
    uoa['AcuteDiagnosis'] = 0
    # Filter diagnosis_df_cleaned for only rows with acute diagnosis
    dx = dx[dx['Acute']].copy()
    # Identify "new" acute events per patient + ICD-9 code
    L365 = pd.Timedelta(days = label_window_days)
    dx['prev_time'] = dx.sort_values(['PatientGuid', 'ICD9Code_cleaned', 'Timestamp']).groupby(['PatientGuid', 'ICD9Code_cleaned'])['Timestamp'].shift(1)
    dx['is_new'] = dx['prev_time'].isna() | ((dx['Timestamp'] - dx['prev_time']) > L365)
    # Check that there are indeed "new" acute events
    if dx[dx['is_new']].empty:
        return uoa
    # Paint positives for relevant as-at-dates
    paint_rows = []
    for pid, grp in dx[dx['is_new'] == True].groupby('PatientGuid'):
        for _, row in grp.iterrows():
            T = pd.Timestamp(row['Timestamp'])
            if pd.notna(row['prev_time']):
                P = row['prev_time']
            else:
                P = pd.NaT
            if pd.notna(P):
                start_ts = max(T - L365, pd.Timestamp(P) + L365)
            else:
                start_ts = T - L365
            # Get 'period end' timestamp of current month
            start_m = pd.Timestamp(start_ts).to_period('M').end_time
            # Get 'period end' timestamp of previous month
            end_m = (pd.Timestamp(T).to_period('M') - 1).end_time
            if start_m <= end_m:
                # Build month-ends as 'period end' timestamps
                pr = pd.period_range(start = start_m.to_period('M'), end = end_m.to_period('M'), freq = 'M').to_timestamp(how = 'end')
                paint_rows.append(pd.DataFrame({'PatientGuid': pid, 'AsAtDate': pr, 'label_hit': 1}))
    if paint_rows:
        label_hits = pd.concat(paint_rows, ignore_index = True).drop_duplicates(['PatientGuid', 'AsAtDate'])
        uoa = uoa.merge(label_hits, on = ['PatientGuid', 'AsAtDate'], how = 'left')
        uoa['AcuteDiagnosis'] = uoa['label_hit'].fillna(0).astype(int)
        uoa.drop(columns = ['label_hit'], inplace = True, errors = 'ignore')
    return uoa
# Q2c add the response variable values to the table you created in Q2b
unit_of_analysis_df = build_response(unit_of_analysis_df, diagnosis_df_cleaned)
# Display first few rows of unit_of_analysis_df to show created response
# variable column - 'AcuteDiagnosis'
display(unit_of_analysis_df.head())
| | PatientGuid | AsAtDate | set | AcuteDiagnosis |
|---|---|---|---|---|
| 0 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-01-31 23:59:59.999999999 | test | 0 |
| 1 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-02-28 23:59:59.999999999 | test | 0 |
| 2 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-03-31 23:59:59.999999999 | test | 0 |
| 3 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-04-30 23:59:59.999999999 | test | 0 |
| 4 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-05-31 23:59:59.999999999 | test | 0 |
Q2c Step 2 - Check 1¶
# Q2c Step 2
# Get counts of 0s and 1s for each set (value_counts excludes NaNs by default)
acute_diagnosis_counts = unit_of_analysis_df.groupby('set')['AcuteDiagnosis'].value_counts().unstack()
# Calculate the percentage of 1s for each set
percentage_ones = unit_of_analysis_df.groupby('set')['AcuteDiagnosis'].value_counts(normalize = True).unstack()[1] * 100
# Merge counts and percentage and display as a single DataFrame
print("Summary of AcuteDiagnosis for each set:")
display(acute_diagnosis_counts.merge(percentage_ones.rename('% of 1s'), left_index = True, right_index = True).reindex(['train', 'val', 'test']))
Summary of AcuteDiagnosis for each set:
| set | 0 | 1 | % of 1s |
|---|---|---|---|
| train | 53410 | 16745 | 23.868577 |
| val | 10101 | 3330 | 24.793388 |
| test | 11457 | 3661 | 24.216166 |
Q2c Step 3 - Check 2¶
# Q2c Step 3
# Find a patient where AcuteDiagnosis = 1
display(unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == '00fa9de1-2ea2-44bd-be32-80499af79e97'].sort_values(by = 'AsAtDate'))
display(diagnosis_df_cleaned[diagnosis_df_cleaned['PatientGuid'] == '00fa9de1-2ea2-44bd-be32-80499af79e97'].sort_values(by = 'Timestamp'))
| | PatientGuid | AsAtDate | set | AcuteDiagnosis |
|---|---|---|---|---|
| 108 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2021-12-31 23:59:59.999999999 | train | 0 |
| 109 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-01-31 23:59:59.999999999 | train | 0 |
| 110 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-02-28 23:59:59.999999999 | train | 0 |
| 111 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-03-31 23:59:59.999999999 | train | 0 |
| 112 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-04-30 23:59:59.999999999 | train | 0 |
| 113 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-05-31 23:59:59.999999999 | train | 0 |
| 114 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-06-30 23:59:59.999999999 | train | 0 |
| 115 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-07-31 23:59:59.999999999 | train | 0 |
| 116 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-08-31 23:59:59.999999999 | train | 0 |
| 117 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-09-30 23:59:59.999999999 | train | 0 |
| 118 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-10-31 23:59:59.999999999 | train | 0 |
| 119 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-11-30 23:59:59.999999999 | train | 1 |
| 120 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2022-12-31 23:59:59.999999999 | train | 1 |
| 121 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-01-31 23:59:59.999999999 | train | 1 |
| 122 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-02-28 23:59:59.999999999 | train | 1 |
| 123 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-03-31 23:59:59.999999999 | train | 1 |
| 124 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-04-30 23:59:59.999999999 | train | 1 |
| 125 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-05-31 23:59:59.999999999 | train | 1 |
| 126 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-06-30 23:59:59.999999999 | train | 1 |
| 127 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-07-31 23:59:59.999999999 | train | 1 |
| 128 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-08-31 23:59:59.999999999 | train | 1 |
| 129 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-09-30 23:59:59.999999999 | train | 1 |
| 130 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-10-31 23:59:59.999999999 | train | 1 |
| 131 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-11-30 23:59:59.999999999 | train | 0 |
| 132 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-12-31 23:59:59.999999999 | train | 0 |
| | DiagnosisGuid | PatientGuid | Timestamp | tz_offset | ICD9Code | DiagnosisDescription | Acute | ICD9Code_cleaned |
|---|---|---|---|---|---|---|---|---|
| 19117 | ba4ce640-7f7a-42fe-8786-570dd10fc993 | 00fa9de1-2ea2-44bd-be32-80499af79e97 | 2023-11-25 15:38:47 | -06:00 | 7823 | Edema | True | 782 |
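The positive rows in the table above can be reproduced by hand from the single acute diagnosis shown. The following is a minimal sketch of that arithmetic, assuming the same 365-day label window and month-end convention as `build_response`; the variable names `window` and `painted` are introduced here for illustration only.

```python
import pandas as pd

# Timestamp of the acute diagnosis (Edema) displayed above
T = pd.Timestamp('2023-11-25 15:38:47')
window = pd.Timedelta(days=365)

# First as-at month whose 12-month look-ahead can still reach T
start_m = (T - window).to_period('M').end_time

# Last as-at month strictly before the event month
end_m = (T.to_period('M') - 1).end_time

# As-at months that should carry AcuteDiagnosis = 1
painted = pd.period_range(start_m.to_period('M'), end_m.to_period('M'), freq='M')
print(f"First positive month: {painted[0]}")
print(f"Last positive month: {painted[-1]}")
print(f"Number of positive months: {len(painted)}")
```

The result, November 2022 to October 2023 inclusive (12 months), agrees with the rows where `AcuteDiagnosis` equals 1 in the table.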
Q2c Step 4 - Check 3¶
# Q2c Step 4
# Find a patient where AcuteDiagnosis = 0 but original diagnosis was acute (due
# to the same diagnosis occurring within 12 months prior as defined by the
# cleaned ICD-9 codes)
display(unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == '416314ac-529a-4de8-bdc3-1a79b9125784'].sort_values(by = 'AsAtDate'))
display(diagnosis_df_cleaned[diagnosis_df_cleaned['PatientGuid'] == '416314ac-529a-4de8-bdc3-1a79b9125784'].sort_values(by = 'Timestamp'))
| | PatientGuid | AsAtDate | set | AcuteDiagnosis |
|---|---|---|---|---|
| 25449 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2021-12-31 23:59:59.999999999 | train | 0 |
| 25450 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-01-31 23:59:59.999999999 | train | 0 |
| 25451 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-02-28 23:59:59.999999999 | train | 0 |
| 25452 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-03-31 23:59:59.999999999 | train | 1 |
| 25453 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-04-30 23:59:59.999999999 | train | 1 |
| 25454 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-05-31 23:59:59.999999999 | train | 1 |
| 25455 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-06-30 23:59:59.999999999 | train | 1 |
| 25456 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-07-31 23:59:59.999999999 | train | 1 |
| 25457 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-08-31 23:59:59.999999999 | train | 1 |
| 25458 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-09-30 23:59:59.999999999 | train | 1 |
| 25459 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-10-31 23:59:59.999999999 | train | 1 |
| 25460 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-11-30 23:59:59.999999999 | train | 1 |
| 25461 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-12-31 23:59:59.999999999 | train | 1 |
| 25462 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-01-31 23:59:59.999999999 | train | 1 |
| 25463 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-02-28 23:59:59.999999999 | train | 1 |
| 25464 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-03-31 23:59:59.999999999 | train | 0 |
| 25465 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-04-30 23:59:59.999999999 | train | 0 |
| 25466 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-05-31 23:59:59.999999999 | train | 0 |
| 25467 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-06-30 23:59:59.999999999 | train | 0 |
| 25468 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-07-31 23:59:59.999999999 | train | 0 |
| 25469 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-08-31 23:59:59.999999999 | train | 0 |
| 25470 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-09-30 23:59:59.999999999 | train | 0 |
| 25471 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-10-31 23:59:59.999999999 | train | 0 |
| 25472 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-11-30 23:59:59.999999999 | train | 0 |
| 25473 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-12-31 23:59:59.999999999 | train | 0 |
| | DiagnosisGuid | PatientGuid | Timestamp | tz_offset | ICD9Code | DiagnosisDescription | Acute | ICD9Code_cleaned |
|---|---|---|---|---|---|---|---|---|
| 20021 | c38a86ad-9c0c-4953-a1de-795a746ea120 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-04-20 18:42:31 | -06:00 | 1121 | Candidiasis of vulva and vagina | False | 112 |
| 18599 | b52afb08-4196-46da-a806-9a9cbdcdccec | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-06-14 21:31:45 | -06:00 | 53081 | Esophageal reflux | False | 530 |
| 9418 | 5aa718d4-fba3-4cf8-a9a1-33b0a6b5e5ae | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-06-25 17:05:32 | -06:00 | 7242 | Lumbago | False | 724 |
| 14605 | 8cb0730e-b803-49c0-b0f1-a5d1e967807d | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-07-27 21:28:08 | -06:00 | 78093 | Memory loss NOS | False | 780 |
| 11426 | 6e8c8bce-59ea-484b-b7a2-44c7222563fe | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-08-22 18:46:57 | -06:00 | 4019 | Unspecified essential hypertension | False | 401 |
| 9476 | 5b40d749-bfe0-4e52-93f7-f4a320411ce3 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-10-02 15:30:43 | -06:00 | 2724 | Other and unspecified hyperlipidemia | False | 272 |
| 11181 | 6c2f7893-cd94-4575-90f5-be916645a7c5 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2022-12-11 20:33:12 | -06:00 | 30272 | Psychosexual dysfunction with inhibited sexual... | False | 302 |
| 17465 | a96d04e5-5bd6-4671-9510-dfcc9676d76e | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-03-31 20:37:44 | -06:00 | 78060 | Fever, unspecified | True | 780 |
| 13772 | 84b4fb5c-fd36-414f-8418-97804542028d | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2023-10-24 18:01:00 | -06:00 | 054 | Herpes simplex | False | 054 |
| 19892 | c21bcb6b-79d5-4c79-b0a0-beb36e1dd4f2 | 416314ac-529a-4de8-bdc3-1a79b9125784 | 2024-01-22 21:56:31 | -06:00 | 78051 | Insomnia with sleep apnea, unspecified | True | 780 |
Q2c Step 5 - Check 4
# Q2c Step 5
# Find a patient with AcuteDiagnosis = 1 for only part of the 12-month window
# (because the same diagnosis, as defined by the cleaned ICD-9 code, occurred
# slightly more than 12 months earlier)
display(unit_of_analysis_df[unit_of_analysis_df['PatientGuid'] == '38c9ef86-c4e4-41e7-98a0-e1a42cf2106a'].sort_values(by = 'AsAtDate'))
display(diagnosis_df_cleaned[diagnosis_df_cleaned['PatientGuid'] == '38c9ef86-c4e4-41e7-98a0-e1a42cf2106a'].sort_values(by = 'Timestamp'))
| | PatientGuid | AsAtDate | set | AcuteDiagnosis |
|---|---|---|---|---|
| 21743 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-04-30 23:59:59.999999999 | train | 0 |
| 21744 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-05-31 23:59:59.999999999 | train | 0 |
| 21745 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-06-30 23:59:59.999999999 | train | 0 |
| 21746 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-07-31 23:59:59.999999999 | train | 0 |
| 21747 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-08-31 23:59:59.999999999 | train | 0 |
| 21748 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-09-30 23:59:59.999999999 | train | 0 |
| 21749 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-10-31 23:59:59.999999999 | train | 0 |
| 21750 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-11-30 23:59:59.999999999 | train | 0 |
| 21751 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-12-31 23:59:59.999999999 | train | 1 |
| 21752 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-01-31 23:59:59.999999999 | train | 1 |
| 21753 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-02-28 23:59:59.999999999 | train | 1 |
| 21754 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-03-31 23:59:59.999999999 | train | 0 |
| 21755 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-04-30 23:59:59.999999999 | train | 0 |
| 21756 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-05-31 23:59:59.999999999 | train | 0 |
| 21757 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-06-30 23:59:59.999999999 | train | 0 |
| 21758 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-07-31 23:59:59.999999999 | train | 0 |
| 21759 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-08-31 23:59:59.999999999 | train | 0 |
| 21760 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-09-30 23:59:59.999999999 | train | 0 |
| 21761 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-10-31 23:59:59.999999999 | train | 0 |
| 21762 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-11-30 23:59:59.999999999 | train | 0 |
| 21763 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-12-31 23:59:59.999999999 | train | 0 |
| 21764 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-01-31 23:59:59.999999999 | train | 0 |
| 21765 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-02-28 23:59:59.999999999 | train | 0 |
| 21766 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-03-31 23:59:59.999999999 | train | 0 |
| 21767 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-04-30 23:59:59.999999999 | train | 0 |
| 21768 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-05-31 23:59:59.999999999 | train | 0 |
| 21769 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-06-30 23:59:59.999999999 | train | 0 |
| 21770 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-07-31 23:59:59.999999999 | train | 0 |
| 21771 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-08-31 23:59:59.999999999 | train | 0 |
| 21772 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-09-30 23:59:59.999999999 | train | 0 |
| 21773 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-10-31 23:59:59.999999999 | train | 0 |
| 21774 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-11-30 23:59:59.999999999 | train | 0 |
| | DiagnosisGuid | PatientGuid | Timestamp | tz_offset | ICD9Code | DiagnosisDescription | Acute | ICD9Code_cleaned |
|---|---|---|---|---|---|---|---|---|
| 11985 | 7413d253-5510-4d06-9cdf-05e2fb6f0e2e | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2020-12-15 15:07:14 | -06:00 | 5651 | Anal fistula | False | 565 |
| 2427 | 16f1fcaf-9ef2-4462-9c6c-3db5dc6e18a4 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2020-12-26 19:17:30 | -06:00 | 7030 | Ingrowing nail | True | 703 |
| 18869 | b7e932ec-8c01-44b6-bd1b-868e3c8b7140 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2021-09-30 19:06:30 | -06:00 | 7295 | Pain in limb | False | 729 |
| 11731 | 71ad7b60-ab64-4343-88b0-c3dbd89ece48 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-03-09 17:27:34 | -06:00 | 7823 | Edema | False | 782 |
| 10739 | 67da02f3-9227-4b5f-9f90-0476e8cf12ae | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-03-29 21:22:42 | -06:00 | 7038 | Other specified diseases of nail | True | 703 |
| 12302 | 772b63a1-5a0c-4345-a044-f73a6412e947 | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2022-06-06 17:46:36 | -06:00 | 6811 | Cellulitis and abscess of toe | False | 681 |
| 2305 | 15aeb5b1-f4f6-4bef-b9a2-436a1022fb4f | 38c9ef86-c4e4-41e7-98a0-e1a42cf2106a | 2023-07-10 19:38:54 | -06:00 | 7245 | Backache, unspecified | False | 724 |
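The two tables above are consistent with a labelling rule of the following shape. This is only a sketch inferred from the displayed output, not the notebook's actual target-construction code (which is not shown in this section). The function name `label_acute`, the use of calendar months for both windows, and the exact boundary handling are all assumptions.

```python
import pandas as pd

def label_acute(as_at: pd.Timestamp, dx: pd.DataFrame) -> int:
    """Hypothetical rule: 1 if an acute diagnosis falls in (as_at, as_at + 12m]
    and its cleaned ICD-9 code was not seen in (as_at - 12m, as_at]."""
    window = pd.DateOffset(months=12)  # assumption: calendar months, not 365 days
    prior = dx[(dx['Timestamp'] > as_at - window) & (dx['Timestamp'] <= as_at)]
    seen_codes = prior['ICD9Code_cleaned'].unique()
    future = dx[(dx['Timestamp'] > as_at)
                & (dx['Timestamp'] <= as_at + window)
                & dx['Acute']]
    return int((~future['ICD9Code_cleaned'].isin(seen_codes)).any())

# The two acute diagnoses (cleaned code 703) of the patient shown above
dx = pd.DataFrame({
    'Timestamp': pd.to_datetime(['2020-12-26 19:17:30', '2022-03-29 21:22:42']),
    'Acute': [True, True],
    'ICD9Code_cleaned': ['703', '703'],
})

for date in ['2021-11-30', '2021-12-31', '2022-02-28', '2022-03-31']:
    as_at = pd.Timestamp(date + ' 23:59:59')
    print(date, label_acute(as_at, dx))
```

Under these assumptions the label is 0 at 2021-11-30 (the December 2020 diagnosis is still inside the look-back window), 1 from 2021-12-31 to 2022-02-28, and 0 again from 2022-03-31, which matches the AcuteDiagnosis column displayed for this patient.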
Q2d Suggest four metrics
Q2e - Construct neural network
Q2e Step 1 - Feature Engineering
Creation of Age-Last-Birthday as of As-At-Date
# Q2e Step 1
# Count the number of distinct DateOfBirth values for each PatientGuid
dob_counts = patient_df_cleaned.groupby('PatientGuid')['DateOfBirth'].nunique().reset_index(name = 'count')
# Display PatientGuids that have more than one distinct DateOfBirth
multiple_dob = dob_counts[dob_counts['count'] > 1]
if not multiple_dob.empty:
    print("PatientGuid and DateOfBirth combinations appearing more than once:")
    display(multiple_dob)
else:
    print("No duplicate PatientGuid and DateOfBirth combinations found.")
# Get dob information into unit of analysis
# No need to use validity periods since each patient only has one dob and we do
# not expect dob to be time-varying
unit_of_analysis_df_v2 = pd.merge(unit_of_analysis_df, patient_df_cleaned[['PatientGuid', 'DateOfBirth']].drop_duplicates(), on = 'PatientGuid', how = 'left')
# Calculate age for each observation as of as-at-date
unit_of_analysis_df_v2['AgeLastBirthday'] = (pd.to_datetime(unit_of_analysis_df_v2['AsAtDate']) - pd.to_datetime(unit_of_analysis_df_v2['DateOfBirth'])).dt.days // 365
# Drop DateOfBirth from unit_of_analysis_df_v2
unit_of_analysis_df_v2 = unit_of_analysis_df_v2.drop(columns = ['DateOfBirth'])
# Confirm that the number of rows in the resulting dataframe has not changed
# (the updated dataframe is displayed first, the original second)
print("\nShape of unit_of_analysis_df before and after calculation of age:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Confirm that all observations have age
print("\nRows in unit_of_analysis_df_v2 where age is empty:")
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['AgeLastBirthday'].isna()])
# Display the first few observations to verify the age calculation
display(unit_of_analysis_df_v2.head())
display(patient_df_cleaned[patient_df_cleaned['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49'])
No duplicate PatientGuid and DateOfBirth combinations found.
Shape of unit_of_analysis_df before and after calculation of age:
(98704, 5)
(98704, 4)
Rows in unit_of_analysis_df_v2 where age is empty:
| PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday |
|---|---|---|---|---|
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday |
|---|---|---|---|---|---|
| 0 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-01-31 23:59:59.999999999 | test | 0 | 26 |
| 1 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-02-28 23:59:59.999999999 | test | 0 | 26 |
| 2 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-03-31 23:59:59.999999999 | test | 0 | 27 |
| 3 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-04-30 23:59:59.999999999 | test | 0 | 27 |
| 4 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-05-31 23:59:59.999999999 | test | 0 | 27 |
| | RowID | ValidFrom | ValidTo | PatientGuid | Gender | DateOfBirth | StateCode | BloodType |
|---|---|---|---|---|---|---|---|---|
| 0 | 6519cb4e-8bb0-463f-80bd-e903b538ef40 | 2020-01-01 | 2099-12-31 23:59:59 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | M | 1994-03-27 | MI | A+ |
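The age calculation above divides elapsed days by 365, which ignores leap days and can therefore report a birthday a few days early. A possible alternative, shown here only as a sketch, compares calendar fields directly. The helper name `age_last_birthday` is hypothetical and is not part of the notebook.

```python
import pandas as pd

def age_last_birthday(dob: pd.Series, as_at: pd.Series) -> pd.Series:
    """Whole years completed at as_at, based on calendar dates."""
    dob = pd.to_datetime(dob)
    as_at = pd.to_datetime(as_at)
    years = as_at.dt.year - dob.dt.year
    # Subtract one year when the birthday has not yet occurred in the as-at year
    before_birthday = (as_at.dt.month < dob.dt.month) | (
        (as_at.dt.month == dob.dt.month) & (as_at.dt.day < dob.dt.day)
    )
    return years - before_birthday.astype(int)

dob = pd.Series(pd.to_datetime(['1994-03-27', '1994-03-27', '2000-01-01']))
as_at = pd.Series(pd.to_datetime(['2021-02-28', '2021-03-31', '2019-12-31']))

ages = age_last_birthday(dob, as_at)
approx = (as_at - dob).dt.days // 365
print(pd.DataFrame({'exact': ages, 'days_div_365': approx}))
```

The first two rows reproduce the ages 26 and 27 displayed above for patient 00548e9d. The third row is a constructed example: someone born on 2000-01-01 is still 19 on 2019-12-31, but the five intervening leap days push the day count to 7304, so the division by 365 already returns 20.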
Creation of Number of Unique ICD-9 "Group1" Descriptions (12-Month History)
# Q2e Step 1
def add_icd9_12m_uniques(uoa: pd.DataFrame, diagnosis: pd.DataFrame, icd9: pd.DataFrame) -> pd.DataFrame:
    """
    For each (PatientGuid, AsAtDate) in unit_of_analysis_df, compute:
    - unique_ICD9CodesDesc_p12m: # unique Group1 descriptions in (AsAtDate-365d, AsAtDate]
    """
    # Join icd9 to bring Group1 into diagnosis
    dx = diagnosis.merge(icd9[['ICD9Code_cleaned', 'Group1']], on = 'ICD9Code_cleaned', how = 'left')
    # Sort for per-patient binary searches
    dx = dx.sort_values(['PatientGuid', 'Timestamp']).reset_index(drop = True)
    uoa = uoa.sort_values(['PatientGuid', 'AsAtDate']).reset_index(drop = True)
    # Prepare result arrays aligned to uoa rows
    uniq_groups = np.zeros(len(uoa), dtype = np.int32)
    # Group both frames by patient for efficient windowed counts
    dx_grp = {pid: g for pid, g in dx.groupby('PatientGuid', sort = False)}
    uoa_grp = uoa.groupby('PatientGuid', sort = False)
    # Loop patients (fast enough because snapshots are monthly)
    for pid, u in uoa_grp:
        if pid not in dx_grp:
            # No diagnoses for this patient -> zeros
            continue
        d = dx_grp[pid]
        times = d['Timestamp'].to_numpy()
        groups = d['Group1'].to_numpy()
        # For each AsAtDate, find indices of events in (S - 365, S]
        idxs = u.index.to_numpy()
        asofs = u['AsAtDate'].to_numpy()
        for i, S in zip(idxs, asofs):
            if pd.isna(S):
                uniq_groups[i] = 0
                continue
            start = S - np.timedelta64(365, 'D')
            # Left bound: first event strictly > start (exclude exactly S - 365)
            L = np.searchsorted(times, start, side = 'right')
            # Right bound: one past the last event <= S (inclusive of S)
            R = np.searchsorted(times, S, side = 'right')
            if L >= R:
                uniq_groups[i] = 0
            else:
                # Get number of unique Group1 descriptions
                uniq_groups[i] = pd.unique(groups[L:R]).size
    # Attach results to the sorted frame
    # uniq_groups is aligned to the 0..n-1 index of the sorted uoa
    out = uoa.copy()
    out['unique_ICD9CodesDesc_p12m'] = uniq_groups
    # Note: uoa was re-indexed after sorting, so the rows are returned in
    # (PatientGuid, AsAtDate) order with a fresh 0..n-1 index
    original_order_uoa = uoa.sort_index().copy()
    out = out.set_index(original_order_uoa.index)
    return out
# Attach columns to unit of analysis
unit_of_analysis_df_v2 = add_icd9_12m_uniques(unit_of_analysis_df_v2, diagnosis_df_cleaned, icd9_df_cleaned_v2)
# Confirm that the number of rows in the resulting dataframe has not changed
# (the updated dataframe is displayed first, the original second)
print("\nShape of unit_of_analysis_df before and after inserting unique counts of ICD-9 codes and their descriptions:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Check for any missing counts
print(f"\nNumber of rows in unit_of_analysis_df_v2 where the number of unique ICD-9 code descriptions is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['unique_ICD9CodesDesc_p12m'].isna()].shape[0]}")
# Sample check a patient at a particular as-at-date to verify that the number of
# unique ICD-9 Group1 descriptions is correctly calculated
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2021-01-31').date())])
display(diagnosis_df_cleaned[(diagnosis_df_cleaned['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (diagnosis_df_cleaned['Timestamp'].dt.date <= pd.to_datetime('2021-01-31').date())].sort_values(by = 'Timestamp'))
display(icd9_df_cleaned_v2[icd9_df_cleaned_v2['ICD9Code_cleaned'].isin(['272', '268', '401'])])
Shape of unit_of_analysis_df before and after inserting unique counts of ICD-9 codes and their descriptions:
(98704, 6)
(98704, 4)
Number of rows in unit_of_analysis_df_v2 where the number of unique ICD-9 code descriptions is empty: 0
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m |
|---|---|---|---|---|---|---|
| 0 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-01-31 23:59:59.999999999 | test | 0 | 26 | 2 |
| | DiagnosisGuid | PatientGuid | Timestamp | tz_offset | ICD9Code | DiagnosisDescription | Acute | ICD9Code_cleaned |
|---|---|---|---|---|---|---|---|---|
| 6625 | 3eed7773-1cfd-494d-b2f3-c1c9ecab4e43 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-01-18 16:19:28 | -05:00 | 2449 | Unspecified hypothyroidism | False | 244 |
| 6478 | 3d91ab77-647d-4412-935e-b43cc19c8f51 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-02-19 14:18:37 | -05:00 | 2722 | Mixed hyperlipidemia | False | 272 |
| 13362 | 80c9cd6c-a00d-4024-b331-897c765147be | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-11-05 21:20:25 | -05:00 | 2689 | Unspecified vitamin D deficiency | False | 268 |
| 18256 | b1f95ef9-1e8f-4d41-bcd4-25199265b150 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-11-18 14:15:58 | -05:00 | 4011 | Benign essential hypertension | False | 401 |
| | ICD9Code_cleaned | Group1 | Group2 | Group3 |
|---|---|---|---|---|
| 243 | 268 | Endocrine, Nutritional And Metabolic Diseases,... | Nutritional Deficiencies | Vitamin d deficiency |
| 247 | 272 | Endocrine, Nutritional And Metabolic Diseases,... | Other Metabolic Disorders And Immunity Disorders | Disorders of lipoid metabolism |
| 372 | 401 | Diseases Of The Circulatory System | Hypertensive Disease | Essential hypertension |
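The look-back window in the function above relies on two calls to `np.searchsorted` with `side = 'right'`, which together select the half-open interval (AsAtDate - 365 days, AsAtDate]. The sketch below isolates that step using the four diagnosis timestamps shown for this patient. The helper name `window_bounds` is hypothetical, and the as-at timestamp is truncated to whole seconds for readability.

```python
import numpy as np

def window_bounds(times: np.ndarray, as_at: np.datetime64, days: int = 365):
    """Return (L, R) so that times[L:R] lies in (as_at - days, as_at].

    `times` must already be sorted in ascending order.
    """
    start = as_at - np.timedelta64(days, 'D')
    # side='right' skips events equal to start, so the left edge is exclusive
    left = np.searchsorted(times, start, side='right')
    # side='right' keeps events equal to as_at, so the right edge is inclusive
    right = np.searchsorted(times, as_at, side='right')
    return int(left), int(right)

times = np.array(
    ['2020-01-18T16:19:28', '2020-02-19T14:18:37',
     '2020-11-05T21:20:25', '2020-11-18T14:15:58'],
    dtype='datetime64[ns]',
)
as_at = np.datetime64('2021-01-31T23:59:59', 'ns')

L, R = window_bounds(times, as_at)
print(L, R, times[L:R])

# Boundary behaviour: an event exactly 365 days earlier is excluded,
# while an event exactly at the as-at timestamp is included
edge = np.array([as_at - np.timedelta64(365, 'D'), as_at])
print(window_bounds(edge, as_at))
```

With these inputs the January 2020 diagnosis falls outside the window and the remaining three are selected. Those three map to two distinct Group1 descriptions in the lookup table above, which agrees with the count of 2 displayed for this observation.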
Re-creation of BMI
# Q2e Step 1
# The recorded BMI appears unreliable, judging by its approximate distribution
# Define bins and labels to group BMI
bmi_bins = [0, 15, 20, 25, 30, 35, 40, 45, 50, float('inf')]
bmi_labels = ["<=15", "16 to 20", "21 to 25", "26 to 30", "31 to 35", "36 to 40", "41 to 45", "46 to 50", ">=51"]
# Display counts of BMI by group
print("Count of BMI grouped by bins:")
display(pd.cut(visit_df_cleaned_split_v2['BMI'], bins = bmi_bins, labels = bmi_labels, right = True).value_counts().sort_index())
# Example: Assuming height is in inches and weight is in pounds, data in index 16466 suggests a calculated BMI of roughly 42.5 whereas the data shows a BMI of roughly 53.
print("Example observation where recorded BMI and calculated BMI (from height and weight) are inconsistent:")
display(visit_df_cleaned_split_v2.loc[[16466]])
# Solution: Easier to deal with height and weight outliers as we are more familiar with those measurements, so clean height and weight and recalculate BMI.
# Calculate BMI using the formula: BMI = (Weight in pounds / (Height in inches)^2) * 703
visit_df_cleaned_split_v2['BMI_recalc'] = (visit_df_cleaned_split_v2['Weight'] / (visit_df_cleaned_split_v2['Height'] ** 2)) * 703
visit_df_cleaned_split_v2['BMI_cleaned'] = (visit_df_cleaned_split_v2['Weight_imputed'] / (visit_df_cleaned_split_v2['Height_imputed'] ** 2)) * 703
# Display the first few rows to verify the new column
display(visit_df_cleaned_split_v2.head())
# Verify that there are no blanks in BMI_cleaned
print("\nRows in visit_df_cleaned_split_v2 where BMI_cleaned is empty:")
display(visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['BMI_cleaned'].isna()])
# Display counts of BMI by group
# Close to 90% of observations are in the 21 to 35 range, which is a sign that
# our calculations are reasonable
bmi_cleaned_counts = pd.cut(visit_df_cleaned_split_v2['BMI_cleaned'], bins = bmi_bins, labels = bmi_labels, right = True).value_counts().sort_index()
bmi_cleaned_percentages = pd.cut(visit_df_cleaned_split_v2['BMI_cleaned'], bins = bmi_bins, labels = bmi_labels, right = True).value_counts(normalize = True).sort_index() * 100
print("Count and percentage of BMI_cleaned grouped by bins:")
display(pd.DataFrame({'Count': bmi_cleaned_counts, 'Percentage': bmi_cleaned_percentages}))
# Use BMI instead of height and weight: BMI is calculated directly from height
# and weight, so including all three features may cause multicollinearity,
# and using BMI alone keeps the model more parsimonious
Count of BMI grouped by bins:
| BMI | count |
|---|---|
| <=15 | 6 |
| 16 to 20 | 560 |
| 21 to 25 | 5966 |
| 26 to 30 | 14480 |
| 31 to 35 | 6892 |
| 36 to 40 | 2398 |
| 41 to 45 | 635 |
| 46 to 50 | 167 |
| >=51 | 56 |
Example observation where recorded BMI and calculated BMI (from height and weight) are inconsistent:
| | VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | ... | set | Height_missing | Gender | Height_imputed | Weight_missing | Weight_imputed | SystolicBP_missing | SystolicBP_imputed | DiastolicBP_missing | DiastolicBP_imputed |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 16466 | 35f84d8e-1363-4ec6-8e0a-0f9de889f162 | 2fc24822-6a8b-4053-8738-71271647667e | 2020-11-13 20:29:02 | -05:00 | 71.0 | 305.0 | 53.285 | 125.0 | 90.0 | 14.0 | ... | train | 0 | F | 71.0 | 0 | 305.0 | 0 | 125.0 | 0 | 90.0 |
1 rows × 26 columns
| | VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | ... | Gender | Height_imputed | Weight_missing | Weight_imputed | SystolicBP_missing | SystolicBP_imputed | DiastolicBP_missing | DiastolicBP_imputed | BMI_recalc | BMI_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | e3f9ca98-8201-4d45-9271-86ed19697f7a | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-01-13 21:43:58 | -05:00 | NaN | NaN | NaN | NaN | NaN | NaN | ... | M | 67.5 | 1 | 191.25 | 1 | 124.0 | 1 | 78.0 | NaN | 29.508642 |
| 1 | 70f3664e-c866-41db-a6fa-7737d7956b25 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-02-24 18:23:23 | -05:00 | 69.0 | 172.2 | 25.681 | 137.0 | 90.0 | NaN | ... | M | 69.0 | 0 | 172.20 | 0 | 137.0 | 0 | 90.0 | 25.426717 | 25.426717 |
| 2 | 67e314de-25a8-4187-bd98-c57c75b6aa14 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-04-24 13:21:27 | -05:00 | NaN | NaN | NaN | 140.0 | 72.0 | NaN | ... | M | 67.5 | 1 | 191.25 | 0 | 140.0 | 0 | 72.0 | NaN | 29.508642 |
| 3 | 11b534ce-3b18-4ebf-bdeb-9a1fab98740e | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-07-02 16:42:13 | -05:00 | NaN | NaN | NaN | 157.0 | 88.0 | NaN | ... | M | 67.5 | 1 | 191.25 | 0 | 157.0 | 0 | 88.0 | NaN | 29.508642 |
| 4 | 8a108567-fd09-40cf-8323-23e9e23ba9e4 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-07-11 16:03:27 | -05:00 | NaN | NaN | NaN | 138.0 | 85.0 | NaN | ... | M | 67.5 | 1 | 191.25 | 0 | 138.0 | 0 | 85.0 | NaN | 29.508642 |
5 rows × 28 columns
Rows in visit_df_cleaned_split_v2 where BMI_cleaned is empty:
| VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | ... | Gender | Height_imputed | Weight_missing | Weight_imputed | SystolicBP_missing | SystolicBP_imputed | DiastolicBP_missing | DiastolicBP_imputed | BMI_recalc | BMI_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 rows × 28 columns
Count and percentage of BMI_cleaned grouped by bins:
| BMI_cleaned | Count | Percentage |
|---|---|---|
| <=15 | 36 | 0.040381 |
| 16 to 20 | 1900 | 2.131216 |
| 21 to 25 | 9286 | 10.416036 |
| 26 to 30 | 58008 | 65.067133 |
| 31 to 35 | 12449 | 13.963949 |
| 36 to 40 | 5126 | 5.749795 |
| 41 to 45 | 1949 | 2.186179 |
| 46 to 50 | 325 | 0.364550 |
| >=51 | 72 | 0.080762 |
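As a worked check of the imperial BMI formula used above (weight in pounds divided by height in inches squared, multiplied by 703), the sketch below recomputes two values that appear in the displayed rows. The function name `bmi_imperial` is hypothetical, and the units are the same assumption the notebook makes.

```python
def bmi_imperial(weight_lb: float, height_in: float) -> float:
    """BMI from pounds and inches: 703 * weight / height^2."""
    return 703.0 * weight_lb / height_in ** 2

# Row 16466: height 71.0, weight 305.0, recorded BMI 53.285
inconsistent = bmi_imperial(305.0, 71.0)
# Row 1 of the head() output: height 69.0, weight 172.2, recorded BMI 25.681
consistent = bmi_imperial(172.2, 69.0)

print(round(inconsistent, 2), round(consistent, 2))
```

The first value comes out at roughly 42.5, well below the recorded 53.285, which is the inconsistency the notebook points to. The second is roughly 25.43, matching the BMI_recalc column and close to the recorded 25.681.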
Creation of Mean Arterial Pressure
# Q2e Step 1
# Calculate MAP using the formula: MAP = 1/3 * Systolic BP + 2/3 * Diastolic BP
# Use MAP as it captures the information of both systolic and diastolic BP in
# one feature, making the model more parsimonious
visit_df_cleaned_split_v2['MAP'] = 1/3 * visit_df_cleaned_split_v2['SystolicBP'] + 2/3 * visit_df_cleaned_split_v2['DiastolicBP']
visit_df_cleaned_split_v2['MAP_cleaned'] = 1/3 * visit_df_cleaned_split_v2['SystolicBP_imputed'] + 2/3 * visit_df_cleaned_split_v2['DiastolicBP_imputed']
# Display the first few rows to verify the new column
display(visit_df_cleaned_split_v2.head())
# Verify that there are no blanks in MAP_cleaned
print("\nRows in visit_df_cleaned_split_v2 where MAP_cleaned is empty:")
display(visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['MAP_cleaned'].isna()])
| | VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | ... | Weight_missing | Weight_imputed | SystolicBP_missing | SystolicBP_imputed | DiastolicBP_missing | DiastolicBP_imputed | BMI_recalc | BMI_cleaned | MAP | MAP_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | e3f9ca98-8201-4d45-9271-86ed19697f7a | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-01-13 21:43:58 | -05:00 | NaN | NaN | NaN | NaN | NaN | NaN | ... | 1 | 191.25 | 1 | 124.0 | 1 | 78.0 | NaN | 29.508642 | NaN | 93.333333 |
| 1 | 70f3664e-c866-41db-a6fa-7737d7956b25 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-02-24 18:23:23 | -05:00 | 69.0 | 172.2 | 25.681 | 137.0 | 90.0 | NaN | ... | 0 | 172.20 | 0 | 137.0 | 0 | 90.0 | 25.426717 | 25.426717 | 105.666667 | 105.666667 |
| 2 | 67e314de-25a8-4187-bd98-c57c75b6aa14 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-04-24 13:21:27 | -05:00 | NaN | NaN | NaN | 140.0 | 72.0 | NaN | ... | 1 | 191.25 | 0 | 140.0 | 0 | 72.0 | NaN | 29.508642 | 94.666667 | 94.666667 |
| 3 | 11b534ce-3b18-4ebf-bdeb-9a1fab98740e | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-07-02 16:42:13 | -05:00 | NaN | NaN | NaN | 157.0 | 88.0 | NaN | ... | 1 | 191.25 | 0 | 157.0 | 0 | 88.0 | NaN | 29.508642 | 111.000000 | 111.000000 |
| 4 | 8a108567-fd09-40cf-8323-23e9e23ba9e4 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-07-11 16:03:27 | -05:00 | NaN | NaN | NaN | 138.0 | 85.0 | NaN | ... | 1 | 191.25 | 0 | 138.0 | 0 | 85.0 | NaN | 29.508642 | 102.666667 | 102.666667 |
5 rows × 30 columns
Rows in visit_df_cleaned_split_v2 where MAP_cleaned is empty:
| VisitGuid | PatientGuid | Timestamp | tz_offset | Height | Weight | BMI | SystolicBP | DiastolicBP | RespiratoryRate | ... | Weight_missing | Weight_imputed | SystolicBP_missing | SystolicBP_imputed | DiastolicBP_missing | DiastolicBP_imputed | BMI_recalc | BMI_cleaned | MAP | MAP_cleaned |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 rows × 30 columns
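The MAP approximation used above weights diastolic pressure twice as heavily as systolic pressure. It can equivalently be written as diastolic pressure plus one third of the pulse pressure (systolic minus diastolic). The short sketch below checks both forms against two rows of the displayed output. The function names are hypothetical.

```python
def mean_arterial_pressure(systolic: float, diastolic: float) -> float:
    """MAP = 1/3 * systolic + 2/3 * diastolic."""
    return systolic / 3 + 2 * diastolic / 3

def map_from_pulse_pressure(systolic: float, diastolic: float) -> float:
    """Equivalent form: diastolic + (systolic - diastolic) / 3."""
    return diastolic + (systolic - diastolic) / 3

# Rows 1 and 2 of the displayed head(): (137, 90) and (140, 72)
map_row1 = mean_arterial_pressure(137.0, 90.0)
map_row2 = mean_arterial_pressure(140.0, 72.0)
print(round(map_row1, 6), round(map_row2, 6))
```

The results, roughly 105.67 and 94.67, agree with the MAP column shown above.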
Creation of BMI, Mean Arterial Pressure, and Number of Unique Physician Specialty Groups (12-Month History)
# Q2e Step 1
def add_visit_12m_features(uoa: pd.DataFrame, visit: pd.DataFrame, specialty: pd.DataFrame) -> pd.DataFrame:
    """
    For each (PatientGuid, AsAtDate) in unit_of_analysis_df, compute over the lookback window (AsAtDate - 365 days, AsAtDate]:
    - MedianBMI_p12m: median of Visit.BMI_recalc; if NaN, median of Visit.BMI_cleaned
    - MedianMAP_p12m: median of Visit.MAP; if NaN, median of Visit.MAP_cleaned
    - MedianBMI_p12m_missing, MedianMAP_p12m_missing: 1 if the cleaned values were used or the median is still NaN
    - unique_PhysicianSpecialtyGroups_p12m: # unique physician specialty groups
    """
    # Join specialty to bring SpecialtyGroup_cleaned into visit
    visit = visit.merge(specialty[['PhysicianSpecialty', 'SpecialtyGroup_cleaned']], on = 'PhysicianSpecialty', how = 'left')
    # Sort for per-patient binary searches
    visit = visit.sort_values(['PatientGuid', 'Timestamp']).reset_index(drop = True)
    uoa = uoa.sort_values(['PatientGuid', 'AsAtDate']).reset_index(drop = True)
    # Prepare result arrays aligned to uoa rows
    med_bmi = np.full(len(uoa), np.nan, dtype = float)
    bmi_missing = np.zeros(len(uoa), dtype = np.int32)
    med_map = np.full(len(uoa), np.nan, dtype = float)
    map_missing = np.zeros(len(uoa), dtype = np.int32)
    uniq_specialties = np.zeros(len(uoa), dtype = np.int32)
    # Group both frames by patient for efficient windowed counts
    visit_grp = {pid: g for pid, g in visit.groupby('PatientGuid', sort = False)}
    uoa_grp = uoa.groupby('PatientGuid', sort = False)
    # Loop patients (fast enough because snapshots are monthly)
    for pid, u in uoa_grp:
        if pid not in visit_grp:
            # No visits for this patient -> defaults kept (NaN medians, zero counts and flags)
            continue
        g = visit_grp[pid]
        times = g['Timestamp'].to_numpy()
        bmis = g['BMI_recalc'].to_numpy()
        bmis_cln = g['BMI_cleaned'].to_numpy()
        maps = g['MAP'].to_numpy()
        maps_cln = g['MAP_cleaned'].to_numpy()
        specialties = g['SpecialtyGroup_cleaned'].to_numpy()
        # For each AsAtDate, find indices of events in (S - 365, S]
        idxs = u.index.to_numpy()
        asofs = u['AsAtDate'].to_numpy()
        for i, S in zip(idxs, asofs):
            if pd.isna(S):
                uniq_specialties[i] = 0
                continue
            start = S - np.timedelta64(365, 'D')
            # Left bound: first event strictly > start (exclude exactly S - 365)
            L = np.searchsorted(times, start, side = 'right')
            # Right bound: one past the last event <= S (inclusive of S)
            R = np.searchsorted(times, S, side = 'right')
            if L >= R:
                uniq_specialties[i] = 0
            else:
                # Get number of unique specialty groups
                uniq_specialties[i] = pd.unique(specialties[L:R]).size
                # Get median BMI (fall back to cleaned values if all raw values are NaN)
                if np.isnan(np.nanmedian(bmis[L:R])):
                    med_bmi[i] = float(np.median(bmis_cln[L:R]))
                    bmi_missing[i] = 1
                else:
                    med_bmi[i] = float(np.nanmedian(bmis[L:R]))
                # Get median MAP (fall back to cleaned values if all raw values are NaN)
                if np.isnan(np.nanmedian(maps[L:R])):
                    med_map[i] = float(np.median(maps_cln[L:R]))
                    map_missing[i] = 1
                else:
                    med_map[i] = float(np.nanmedian(maps[L:R]))
            # Flag observations whose median is still NaN (e.g. no visits in the window)
            if np.isnan(med_bmi[i]):
                bmi_missing[i] = 1
            if np.isnan(med_map[i]):
                map_missing[i] = 1
    # Attach results to the sorted frame
    # The result arrays are aligned to the 0..n-1 index of the sorted uoa
    out = uoa.copy()
    out['MedianBMI_p12m'] = med_bmi
    out['MedianBMI_p12m_missing'] = bmi_missing
    out['MedianMAP_p12m'] = med_map
    out['MedianMAP_p12m_missing'] = map_missing
    out['unique_PhysicianSpecialtyGroups_p12m'] = uniq_specialties
    # Note: uoa was re-indexed after sorting, so the rows are returned in
    # (PatientGuid, AsAtDate) order with a fresh 0..n-1 index
    original_order_uoa = uoa.sort_index().copy()
    out = out.set_index(original_order_uoa.index)
    return out
# Attach columns to unit of analysis
unit_of_analysis_df_v2 = add_visit_12m_features(unit_of_analysis_df_v2, visit_df_cleaned_split_v2, specialty_df_cleaned)
# Confirm that the number of rows in the resulting dataframe has not changed
# (the updated dataframe is displayed first, the original second)
print("\nShape of unit_of_analysis_df before and after inserting median BMI and MAP:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Sample check a patient at a particular as-at-date to verify that the median
# BMI and MAP are correctly calculated when using raw values
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == 'fff9b5b2-4ec4-4260-b831-0def9f8bfe43') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2021-01-31').date())])
print(f"Median raw BMI: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'fff9b5b2-4ec4-4260-b831-0def9f8bfe43') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2020-01-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2021-01-31').date())].sort_values(by = 'Timestamp')['BMI_recalc'].median()}")
print(f"Median cleaned BMI: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'fff9b5b2-4ec4-4260-b831-0def9f8bfe43') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2020-01-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2021-01-31').date())].sort_values(by = 'Timestamp')['BMI_cleaned'].median()}")
print(f"Median raw MAP: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'fff9b5b2-4ec4-4260-b831-0def9f8bfe43') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2020-01-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2021-01-31').date())].sort_values(by = 'Timestamp')['MAP'].median()}")
print(f"Median cleaned MAP: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'fff9b5b2-4ec4-4260-b831-0def9f8bfe43') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2020-01-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2021-01-31').date())].sort_values(by = 'Timestamp')['MAP_cleaned'].median()}")
# Sample check a patient at a particular as-at-date to verify that the median
# BMI and MAP are correctly calculated when using imputed values
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == 'cac82153-1b73-480c-b767-81e4b68d5cf7') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2020-12-31').date())])
print(f"Median raw BMI: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'cac82153-1b73-480c-b767-81e4b68d5cf7') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2019-12-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2020-12-31').date())].sort_values(by = 'Timestamp')['BMI_recalc'].median()}")
print(f"Median cleaned BMI: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'cac82153-1b73-480c-b767-81e4b68d5cf7') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2019-12-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2020-12-31').date())].sort_values(by = 'Timestamp')['BMI_cleaned'].median()}")
print(f"Median raw MAP: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'cac82153-1b73-480c-b767-81e4b68d5cf7') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2019-12-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2020-12-31').date())].sort_values(by = 'Timestamp')['MAP'].median()}")
print(f"Median cleaned MAP: {visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == 'cac82153-1b73-480c-b767-81e4b68d5cf7') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2019-12-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2020-12-31').date())].sort_values(by = 'Timestamp')['MAP_cleaned'].median()}")
# There are observations with missing median BMI and MAP as their as-at-dates
# fall within a one-year period where the patient did not have visits
print(f"\nNumber of rows in unit_of_analysis_df_v2 where BMI is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['MedianBMI_p12m'].isna()].shape[0]}")
print(f"\nNumber of rows in unit_of_analysis_df_v2 where MAP is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['MedianMAP_p12m'].isna()].shape[0]}")
# Impute using the median of the other observations of the patient
unit_of_analysis_df_v2['MedianBMI_p12m'] = unit_of_analysis_df_v2.groupby('PatientGuid')['MedianBMI_p12m'].transform(lambda x: x.fillna(x.median()))
unit_of_analysis_df_v2['MedianMAP_p12m'] = unit_of_analysis_df_v2.groupby('PatientGuid')['MedianMAP_p12m'].transform(lambda x: x.fillna(x.median()))
# One patient left to manually impute
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['PatientGuid'] == '38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e'])
unit_of_analysis_df_v2.loc[unit_of_analysis_df_v2['PatientGuid'] == '38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e', 'MedianBMI_p12m'] = visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['PatientGuid'] == '38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e']['BMI_cleaned'].median()
unit_of_analysis_df_v2.loc[unit_of_analysis_df_v2['PatientGuid'] == '38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e', 'MedianMAP_p12m'] = visit_df_cleaned_split_v2[visit_df_cleaned_split_v2['PatientGuid'] == '38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e']['MAP_cleaned'].median()
# Check for any remaining observations with missing median BMI and MAP
print(f"\nNumber of rows in unit_of_analysis_df_v2 where BMI is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['MedianBMI_p12m'].isna()].shape[0]}")
print(f"\nNumber of rows in unit_of_analysis_df_v2 where MAP is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['MedianMAP_p12m'].isna()].shape[0]}")
# Check for any missing counts
print(f"\nNumber of rows in unit_of_analysis_df_v2 where the number of unique specialty groups is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['unique_PhysicianSpecialtyGroups_p12m'].isna()].shape[0]}")
# Sample check a patient at a particular as-at-date to verify that the number of
# unique specialty groups is correctly calculated
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '0056cdf7-609c-4c4e-8acc-0aaef6f1998e') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2022-03-31').date())])
display(visit_df_cleaned_split_v2[(visit_df_cleaned_split_v2['PatientGuid'] == '0056cdf7-609c-4c4e-8acc-0aaef6f1998e') & (visit_df_cleaned_split_v2['Timestamp'].dt.date > pd.to_datetime('2021-03-31').date()) & (visit_df_cleaned_split_v2['Timestamp'].dt.date <= pd.to_datetime('2022-03-31').date())]['PhysicianSpecialty'].unique())
display(specialty_df_cleaned[specialty_df_cleaned['PhysicianSpecialty'].isin(['Internal Medicine', 'Unknown'])]['SpecialtyGroup_cleaned'].nunique())
Shape of unit_of_analysis_df before and after inserting median BMI and MAP:
(98704, 11)
(98704, 4)
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 98667 | fff9b5b2-4ec4-4260-b831-0def9f8bfe43 | 2021-01-31 23:59:59.999999999 | train | 0 | 37 | 0 | 34.889338 | 0 | 93.666667 | 0 | 1 |
Median raw BMI: 34.88933779926755
Median cleaned BMI: 39.600946745562126
Median raw MAP: 93.66666666666666
Median cleaned MAP: 93.66666666666666
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 77550 | cac82153-1b73-480c-b767-81e4b68d5cf7 | 2020-12-31 23:59:59.999999999 | train | 0 | 24 | 4 | 29.508642 | 1 | 93.333333 | 1 | 1 |
Median raw BMI: nan
Median cleaned BMI: 29.50864197530864
Median raw MAP: nan
Median cleaned MAP: 93.33333333333333
Number of rows in unit_of_analysis_df_v2 where BMI is empty: 9792
Number of rows in unit_of_analysis_df_v2 where MAP is empty: 9792
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 21728 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2022-08-31 23:59:59.999999999 | train | 1 | 86 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21729 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2022-09-30 23:59:59.999999999 | train | 1 | 86 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21730 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2022-10-31 23:59:59.999999999 | train | 1 | 86 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21731 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2022-11-30 23:59:59.999999999 | train | 1 | 86 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21732 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2022-12-31 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21733 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-01-31 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21734 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-02-28 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21735 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-03-31 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21736 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-04-30 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21737 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-05-31 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21738 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-06-30 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21739 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-07-31 23:59:59.999999999 | train | 1 | 87 | 0 | NaN | 1 | NaN | 1 | 0 |
| 21740 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-08-31 23:59:59.999999999 | train | 0 | 87 | 1 | NaN | 1 | NaN | 1 | 0 |
| 21741 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-09-30 23:59:59.999999999 | train | 0 | 87 | 1 | NaN | 1 | NaN | 1 | 0 |
| 21742 | 38c3f5c2-89e0-4ab6-a5c4-2bd4ef61e39e | 2023-10-31 23:59:59.999999999 | train | 0 | 87 | 1 | NaN | 1 | NaN | 1 | 0 |
Number of rows in unit_of_analysis_df_v2 where BMI is empty: 0
Number of rows in unit_of_analysis_df_v2 where MAP is empty: 0
Number of rows in unit_of_analysis_df_v2 where the number of unique specialty groups is empty: 0
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 50 | 0056cdf7-609c-4c4e-8acc-0aaef6f1998e | 2022-03-31 23:59:59.999999999 | test | 0 | 45 | 0 | 25.956923 | 0 | 104.0 | 0 | 2 |
array(['Internal Medicine', 'Unknown'], dtype=object)
2
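The two-stage imputation used above (per-patient median of the snapshot values first, then the patient's visit-level median as a last resort) can be illustrated on toy data. This is only a sketch: the frames `snapshots` and `visits`, their values, and the helper `impute_median_map` are hypothetical stand-ins, not the notebook's actual tables.

```python
import numpy as np
import pandas as pd

# Toy snapshot table (hypothetical values): patient "a" has one gap,
# patient "b" has no snapshot value at all
snapshots = pd.DataFrame({
    "PatientGuid": ["a", "a", "a", "b", "b"],
    "MedianMAP_p12m": [90.0, np.nan, 94.0, np.nan, np.nan],
})

# Toy visit history, used only as the last-resort fallback
visits = pd.DataFrame({
    "PatientGuid": ["a", "a", "b", "b"],
    "MAP_cleaned": [90.0, 94.0, 100.0, 110.0],
})


def impute_median_map(snapshots: pd.DataFrame, visits: pd.DataFrame) -> pd.DataFrame:
    out = snapshots.copy()
    # Record which rows were missing before any imputation
    out["MedianMAP_p12m_missing"] = out["MedianMAP_p12m"].isna().astype(int)
    # Stage 1: fill gaps with the patient's own median across snapshots
    out["MedianMAP_p12m"] = (
        out.groupby("PatientGuid")["MedianMAP_p12m"]
        .transform(lambda x: x.fillna(x.median()))
    )
    # Stage 2: patients with no snapshot value fall back to their visit-level median
    fallback = visits.groupby("PatientGuid")["MAP_cleaned"].median()
    out["MedianMAP_p12m"] = out["MedianMAP_p12m"].fillna(out["PatientGuid"].map(fallback))
    return out


imputed = impute_median_map(snapshots, visits)
print(imputed)
```

Keeping the `_missing` flag alongside the imputed value lets a downstream model distinguish observed readings from filled-in ones.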
Creation of Number of Abnormal Lab Results (12-Month History)
# Q2e Step 1
def add_abnormal_lab_results(uoa: pd.DataFrame, labresults: pd.DataFrame, labobs: pd.DataFrame) -> pd.DataFrame:
    """
    For each (PatientGuid, AsAtDate) in unit_of_analysis_df, compute:
    - NumberOfAbnormalLabResults_p12m: number of abnormal lab results in (AsAtDate-365d, AsAtDate]
    """
    # Join labresults to get 'PatientGuid' and 'Timestamp'
    lab = labobs.merge(labresults[['LabResultGuid', 'PatientGuid', 'Timestamp']], on = 'LabResultGuid', how = 'left')
    # Sort for per-patient binary searches
    lab = lab.sort_values(['PatientGuid', 'Timestamp']).reset_index(drop = True)
    uoa = uoa.sort_values(['PatientGuid', 'AsAtDate']).reset_index(drop = True)
    # Prepare result array aligned to the sorted uoa rows
    num_abnormal_results = np.zeros(len(uoa), dtype = np.int32)
    # Group both frames by patient for efficient windowed counts
    lab_grp = {pid: g for pid, g in lab.groupby('PatientGuid', sort = False)}
    uoa_grp = uoa.groupby('PatientGuid', sort = False)
    # Loop over patients (fast enough because snapshots are monthly)
    for pid, u in uoa_grp:
        if pid not in lab_grp:
            # No lab results for this patient: keep the default of 0
            continue
        g = lab_grp[pid]
        times = g['Timestamp'].to_numpy()
        abnormals = g['AnyAbnormalValue_cleaned'].to_numpy()
        # For each AsAtDate, find indices of events in (S - 365, S]
        idxs = u.index.to_numpy()
        asofs = u['AsAtDate'].to_numpy()
        for i, S in zip(idxs, asofs):
            if pd.isna(S):
                num_abnormal_results[i] = 0
                continue
            start = S - np.timedelta64(365, 'D')
            # Left bound: first event strictly > start (exclude exactly S - 365)
            L = np.searchsorted(times, start, side = 'right')
            # Right bound: one past the last event <= S (inclusive of S)
            R = np.searchsorted(times, S, side = 'right')
            if L >= R:
                num_abnormal_results[i] = 0
            else:
                # Get number of abnormal lab results
                num_abnormal_results[i] = (abnormals[L:R]).sum()
    # Attach results; the array is aligned to the RangeIndex of the sorted uoa
    out = uoa.copy()
    out['NumberOfAbnormalLabResults_p12m'] = num_abnormal_results
    # Note: rows are returned sorted by (PatientGuid, AsAtDate) with a fresh
    # RangeIndex; the caller's original row order is not restored
    return out
# Attach columns to unit of analysis
unit_of_analysis_df_v2 = add_abnormal_lab_results(unit_of_analysis_df_v2, labresult_df, labobservation_df_cleaned_v2)
# Confirm that the number of rows has not changed (the shape after the insert
# is displayed first, followed by the original shape)
print("\nShape of unit_of_analysis_df before and after inserting number of abnormal lab results:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Check for any missing counts
print(f"\nNumber of rows in unit_of_analysis_df_v2 where the number of abnormal lab results is empty: {unit_of_analysis_df_v2[unit_of_analysis_df_v2['NumberOfAbnormalLabResults_p12m'].isna()].shape[0]}")
# Sample check a patient with 0 abnormal lab results at a particular as-at-date
# to verify that the number has been correctly calculated
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2021-01-31').date())])
display(labresult_df[(labresult_df['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (labresult_df['Timestamp'].dt.date > pd.to_datetime('2020-01-31').date()) & (labresult_df['Timestamp'].dt.date <= pd.to_datetime('2021-01-31').date())])
# Sample check a patient with 2 abnormal lab results at a particular as-at-date
# to verify that the number has been correctly calculated
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '00c5e26d-e323-47c2-bfcb-a2e9fe95f86d') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2022-01-31').date())])
display(labresult_df[(labresult_df['PatientGuid'] == '00c5e26d-e323-47c2-bfcb-a2e9fe95f86d') & (labresult_df['Timestamp'].dt.date > pd.to_datetime('2021-01-31').date()) & (labresult_df['Timestamp'].dt.date <= pd.to_datetime('2022-01-31').date())])
display(labobservation_df_cleaned_v2[labobservation_df_cleaned_v2['LabResultGuid'].isin(['381feedd-90c2-4782-92dd-28bcc2a64c8f', '205fe8c2-34ea-4211-a4ca-670a50d0d236', '66d4e3f8-e799-4623-a09e-1f27a3cdcec1'])])
Shape of unit_of_analysis_df before and after inserting number of abnormal lab results:
(98704, 12)
(98704, 4)
Number of rows in unit_of_analysis_df_v2 where the number of abnormal lab results is empty: 0
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-01-31 23:59:59.999999999 | test | 0 | 26 | 2 | 25.426717 | 0 | 102.5 | 0 | 1 | 0 |
| | LabResultGuid | PatientGuid | Timestamp | tz_offset |
|---|---|---|---|---|
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 84 | 00c5e26d-e323-47c2-bfcb-a2e9fe95f86d | 2022-01-31 23:59:59.999999999 | val | 0 | 69 | 5 | 25.604923 | 0 | 97.666667 | 0 | 1 | 2 |
| | LabResultGuid | PatientGuid | Timestamp | tz_offset |
|---|---|---|---|---|
| 2544 | 381feedd-90c2-4782-92dd-28bcc2a64c8f | 00c5e26d-e323-47c2-bfcb-a2e9fe95f86d | 2021-03-25 20:49:45 | -05:00 |
| 2593 | 205fe8c2-34ea-4211-a4ca-670a50d0d236 | 00c5e26d-e323-47c2-bfcb-a2e9fe95f86d | 2021-03-30 14:23:52 | -05:00 |
| 2788 | 66d4e3f8-e799-4623-a09e-1f27a3cdcec1 | 00c5e26d-e323-47c2-bfcb-a2e9fe95f86d | 2021-04-27 15:50:47 | -05:00 |
| | LabResultGuid | AnyAbnormalValue_cleaned |
|---|---|---|
| 1721 | 205fe8c2-34ea-4211-a4ca-670a50d0d236 | True |
| 2977 | 381feedd-90c2-4782-92dd-28bcc2a64c8f | True |
| 5417 | 66d4e3f8-e799-4623-a09e-1f27a3cdcec1 | False |
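The window logic can be cross-checked in isolation using the three lab results displayed above for the sampled patient. The sketch below re-implements only the `(AsAtDate - 365 days, AsAtDate]` count with `np.searchsorted`; the helper name `count_abnormal_in_window` is hypothetical and the arrays are typed in by hand from the displayed tables.

```python
import numpy as np

# Timestamps and abnormal flags copied from the sample-check output above
# (must be sorted by time for the binary search to be valid)
times = np.array(
    ["2021-03-25T20:49:45", "2021-03-30T14:23:52", "2021-04-27T15:50:47"],
    dtype="datetime64[ns]",
)
abnormal = np.array([True, True, False])
as_at = np.datetime64("2022-01-31T23:59:59.999999999", "ns")


def count_abnormal_in_window(times, flags, as_at, days=365):
    """Count flagged events with start < time <= as_at, where start = as_at - days."""
    start = as_at - np.timedelta64(days, "D")
    # side="right" on the start skips events exactly at the start (open left bound)
    left = np.searchsorted(times, start, side="right")
    # side="right" on as_at keeps events exactly at as_at (closed right bound)
    right = np.searchsorted(times, as_at, side="right")
    return int(flags[left:right].sum())


print(count_abnormal_in_window(times, abnormal, as_at))
```

Using `side="right"` for both bounds is what makes the interval open on the left and closed on the right.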
Creation of Smoker Status
# Q2e Step 1
# Check for rows where validity periods either overlap with the previous one or
# have a start date that comes after the end date
smoking_df['prev_ValidTo'] = smoking_df.sort_values(by = ['PatientGuid', 'ValidFrom']).groupby('PatientGuid')['ValidTo'].shift(1)
display(smoking_df[(smoking_df['ValidFrom'] < smoking_df['prev_ValidTo']) | (smoking_df['ValidTo'] < smoking_df['ValidFrom'])])
# Count the number of distinct smoker descriptions recorded for each PatientGuid
smoking_counts = smoking_df_cleaned.groupby('PatientGuid')['Description'].nunique().reset_index(name = 'count')
# Display PatientGuids where the count is greater than 1
multiple_desc = smoking_counts[smoking_counts['count'] > 1]
if not multiple_desc.empty:
    print("PatientGuid and smoker description combinations appearing more than once:")
    display(multiple_desc)
else:
    print("No duplicate PatientGuid and smoker description combinations found.")
# Attach smoker description information from smoking_df_cleaned into
# unit_of_analysis_df_v2 and map the descriptions to either smoker or non-smoker
# Need to handle validity periods as a patient can change smoker description
# over time
# Treat validity period intervals as "[)"
def attach_smoker_status(uoa: pd.DataFrame, smoker: pd.DataFrame) -> pd.DataFrame:
    """
    Attach smoker status to each row in unit_of_analysis_df while respecting validity periods [ValidFrom, ValidTo) (i.e. left-closed, right-open).
    """
    # For each uoa row, join the latest smoker row with ValidFrom <= AsAtDate
    # within the same PatientGuid
    merged = pd.merge_asof(uoa.sort_values('AsAtDate'), smoker.sort_values('ValidFrom')[['PatientGuid', 'ValidFrom', 'ValidTo', 'Description_cleaned']], left_on = 'AsAtDate', right_on = 'ValidFrom', by = 'PatientGuid')
    # Blank out matches whose validity period does not cover AsAtDate
    in_window = (merged['AsAtDate'] >= merged['ValidFrom']) & (merged['AsAtDate'] < merged['ValidTo'])
    merged.loc[~in_window, 'Description_cleaned'] = pd.NA
    # Map descriptions to smoker (1) or non-smoker (0); previous smokers are
    # treated as non-smokers and unmatched rows stay missing
    merged['SmokerStatus'] = np.where((merged['Description_cleaned'] == '0 cigarettes per day (previous smoker)') | (merged['Description_cleaned'] == '0 cigarettes per day (non-smoker or less than 100 in lifetime)'), 0, np.where(merged['Description_cleaned'].notna(), 1, np.nan))
    # Drop the helper columns brought in by the join
    merged = merged.drop(columns = ['Description_cleaned', 'ValidFrom', 'ValidTo'])
    # Note: merge_asof returns rows sorted by AsAtDate with a fresh RangeIndex;
    # the caller's original row order is not restored
    return merged
# Attach smoker status to unit of analysis
unit_of_analysis_df_v2 = attach_smoker_status(unit_of_analysis_df_v2, smoking_df_cleaned)
# Convert SmokerStatus to nullable int data type
unit_of_analysis_df_v2['SmokerStatus'] = unit_of_analysis_df_v2['SmokerStatus'].astype('Int8')
# Identify PatientGuids in unit_of_analysis_df_v2 not present in
# smoking_df_cleaned
patients_not_in_smoking = unit_of_analysis_df_v2[~unit_of_analysis_df_v2['PatientGuid'].isin(smoking_df_cleaned['PatientGuid'].unique())]['PatientGuid'].unique()
print(f"Percentage of PatientGuids in unit of analysis with no smoking information: {len(patients_not_in_smoking) / len(unit_of_analysis_df_v2['PatientGuid'].unique()) * 100}%")
# Assume non-smoker for these PatientGuids
unit_of_analysis_df_v2.loc[unit_of_analysis_df_v2['PatientGuid'].isin(patients_not_in_smoking), 'SmokerStatus'] = 0
# Confirm that the number of rows has not changed (the shape after the insert
# is displayed first, followed by the original shape)
print("\nShape of unit_of_analysis_df before and after inclusion of smoker status information:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Display rows where smoker status is empty
print("\nRows in unit_of_analysis_df_v2 where smoker status is empty:")
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['SmokerStatus'].isna()])
# Impute these rows with the mode of smoker status within each PatientGuid,
# assuming non-smoker status in case of more than one mode
unit_of_analysis_df_v2['SmokerStatus'] = unit_of_analysis_df_v2.groupby('PatientGuid')['SmokerStatus'].transform(lambda x: x.fillna(0 if x.mode().empty else (0 if (0 in x.mode().values and 1 in x.mode().values) else x.mode().iloc[0])))
# Check for any more rows with missing smoker status
print("\nRows in unit_of_analysis_df_v2 where smoker status is empty:")
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['SmokerStatus'].isna()])
# Sample check a patient with changing smoker status to verify that the smoker
# statuses are correctly mapped
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['PatientGuid'] == '39b1ac90-9407-43c4-a850-50b27f86d299'])
display(smoking_df_cleaned[smoking_df_cleaned['PatientGuid'] == '39b1ac90-9407-43c4-a850-50b27f86d299'])
| | PatientSmokingStatusGuid | PatientGuid | ValidFrom | ValidTo | Description | NISTcode | prev_ValidTo |
|---|---|---|---|---|---|---|---|
PatientGuid and smoker description combinations appearing more than once:
| | PatientGuid | count |
|---|---|---|
| 18 | 030b6c4a-52b5-4acc-b415-bfeca3bc9857 | 2 |
| 33 | 0437f2fa-3df4-42e2-82b7-19f968e15c45 | 2 |
| 37 | 04d9fc0c-4be2-4029-bc24-ad71c2441ad5 | 2 |
| 131 | 1167d1a9-4113-4160-85e2-bfa487c7d08b | 2 |
| 134 | 117627f9-6cd5-49df-973d-0644bc0fc4c8 | 2 |
| ... | ... | ... |
| 1815 | eaa0096e-f1ab-4bf5-b4da-0e89ede7951a | 2 |
| 1853 | ef0e8744-1abd-4c62-a6e7-44b6b8927cb6 | 2 |
| 1885 | f3aa08ef-07d5-4242-bbfe-8c67675698a3 | 2 |
| 1901 | f5d39ae9-2fd3-4abb-ab10-261b9f7377bc | 2 |
| 1986 | ff67aeb6-44bb-4ca3-9fad-9242f2ea6f1c | 2 |
92 rows × 2 columns
Percentage of PatientGuids in unit of analysis with no smoking information: 36.03603603603604%
Shape of unit_of_analysis_df before and after inclusion of smoker status information:
(98704, 13)
(98704, 4)
Rows in unit_of_analysis_df_v2 where smoker status is empty:
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 16 | 2de2e1b6-84d7-4797-b34c-b2abb7e4e895 | 2020-12-31 23:59:59.999999999 | train | 1 | 72 | 2 | 39.814818 | 0 | 108.000000 | 0 | 1 | 1 | <NA> |
| 37 | 33f22636-8784-4fb8-9191-aee1d3afc007 | 2020-12-31 23:59:59.999999999 | val | 0 | 80 | 7 | 29.626951 | 0 | 83.333333 | 0 | 1 | 0 | <NA> |
| 62 | 3b0f07d3-e4e6-4ae3-9d97-e5d528a5d211 | 2020-12-31 23:59:59.999999999 | val | 0 | 21 | 2 | 29.118343 | 1 | 108.000000 | 0 | 1 | 1 | <NA> |
| 90 | ab1d5442-78b4-4241-a6f3-f6efab5a0a8d | 2021-01-31 23:59:59.999999999 | train | 1 | 23 | 1 | 25.041775 | 0 | 79.333333 | 0 | 1 | 0 | <NA> |
| 100 | 03d9d07b-ad4e-4345-b186-99dbc9b809be | 2021-01-31 23:59:59.999999999 | train | 0 | 53 | 2 | 39.101775 | 1 | 88.666667 | 0 | 1 | 0 | <NA> |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 98485 | e9a117b2-88cd-40f5-b90a-8498c59ecac9 | 2023-12-31 23:59:59.999999999 | train | 0 | 21 | 3 | 33.694083 | 1 | 92.166667 | 0 | 1 | 0 | <NA> |
| 98523 | e3d67627-846c-4171-9b49-bab443263706 | 2023-12-31 23:59:59.999999999 | test | 0 | 51 | 6 | 29.508642 | 1 | 93.333333 | 1 | 1 | 1 | <NA> |
| 98602 | 3efe92bd-a2d7-4598-8ebf-6fddf2a61c92 | 2023-12-31 23:59:59.999999999 | val | 0 | 69 | 1 | 29.118343 | 1 | 93.333333 | 1 | 1 | 1 | <NA> |
| 98603 | 5a47dc3a-bb3e-44c2-b41d-a161a6dce14a | 2023-12-31 23:59:59.999999999 | train | 1 | 20 | 1 | 24.625799 | 0 | 86.000000 | 0 | 1 | 2 | <NA> |
| 98670 | 197acb53-064c-4530-a0cc-98de7d7f5cd4 | 2024-01-31 23:59:59.999999999 | train | 0 | 41 | 0 | 37.956214 | 1 | 89.666667 | 0 | 1 | 1 | <NA> |
5365 rows × 13 columns
Rows in unit_of_analysis_df_v2 where smoker status is empty:
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1136 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-01-31 23:59:59.999999999 | val | 0 | 72 | 0 | 42.774324 | 0 | 98.666667 | 0 | 1 | 0 | 0 |
| 2284 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-02-28 23:59:59.999999999 | val | 0 | 72 | 0 | 42.774324 | 0 | 98.666667 | 0 | 1 | 0 | 0 |
| 3618 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-03-31 23:59:59.999999999 | val | 0 | 72 | 1 | 42.840651 | 0 | 97.833333 | 0 | 1 | 0 | 0 |
| 5474 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-04-30 23:59:59.999999999 | val | 0 | 73 | 1 | 42.774324 | 0 | 97.000000 | 0 | 1 | 0 | 0 |
| 7463 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-05-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.707997 | 0 | 97.833333 | 0 | 1 | 0 | 0 |
| 9616 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-06-30 23:59:59.999999999 | val | 0 | 73 | 1 | 42.774324 | 0 | 97.833333 | 0 | 1 | 0 | 0 |
| 12137 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-07-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.774324 | 0 | 97.000000 | 0 | 1 | 0 | 0 |
| 13850 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-08-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.774324 | 0 | 98.666667 | 0 | 1 | 0 | 0 |
| 16518 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-09-30 23:59:59.999999999 | val | 0 | 73 | 1 | 42.774324 | 0 | 98.666667 | 0 | 1 | 0 | 0 |
| 20022 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-10-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.840651 | 0 | 98.666667 | 0 | 1 | 0 | 0 |
| 23597 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-11-30 23:59:59.999999999 | val | 0 | 73 | 1 | 42.754756 | 0 | 99.333333 | 0 | 1 | 0 | 0 |
| 27386 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2021-12-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.754756 | 0 | 100.000000 | 0 | 1 | 1 | 0 |
| 29673 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-01-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.754756 | 0 | 98.333333 | 0 | 1 | 1 | 1 |
| 31917 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-02-28 23:59:59.999999999 | val | 0 | 73 | 1 | 42.754756 | 0 | 98.333333 | 0 | 1 | 1 | 1 |
| 36301 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-03-31 23:59:59.999999999 | val | 0 | 73 | 1 | 42.313233 | 0 | 100.000000 | 0 | 1 | 1 | 1 |
| 38438 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-04-30 23:59:59.999999999 | val | 0 | 74 | 1 | 42.313233 | 0 | 101.333333 | 0 | 1 | 1 | 1 |
| 41303 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-05-31 23:59:59.999999999 | val | 0 | 74 | 1 | 42.731377 | 1 | 97.833333 | 1 | 0 | 1 | 1 |
| 44621 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-06-30 23:59:59.999999999 | val | 0 | 74 | 1 | 42.731377 | 1 | 97.833333 | 1 | 0 | 1 | 1 |
| 48429 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-07-31 23:59:59.999999999 | val | 0 | 74 | 1 | 42.731377 | 1 | 97.833333 | 1 | 0 | 1 | 1 |
| 48972 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-08-31 23:59:59.999999999 | val | 0 | 74 | 1 | 42.731377 | 1 | 97.833333 | 1 | 0 | 2 | 1 |
| 53779 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-09-30 23:59:59.999999999 | val | 0 | 74 | 1 | 42.731377 | 1 | 97.833333 | 1 | 0 | 2 | 1 |
| 56276 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-10-31 23:59:59.999999999 | val | 0 | 74 | 1 | 42.731377 | 1 | 97.833333 | 1 | 0 | 2 | 1 |
| 60683 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-11-30 23:59:59.999999999 | val | 0 | 74 | 2 | 42.731377 | 1 | 97.833333 | 1 | 0 | 2 | 1 |
| 61423 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-12-31 23:59:59.999999999 | val | 0 | 74 | 2 | 42.731377 | 1 | 97.833333 | 1 | 0 | 1 | 1 |
| 67190 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-01-31 23:59:59.999999999 | val | 0 | 74 | 2 | 30.017906 | 0 | 94.666667 | 0 | 1 | 1 | 1 |
| 69636 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-02-28 23:59:59.999999999 | val | 0 | 74 | 2 | 30.017906 | 0 | 94.666667 | 0 | 1 | 2 | 1 |
| 71035 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-03-31 23:59:59.999999999 | val | 0 | 74 | 2 | 30.017906 | 0 | 94.666667 | 0 | 1 | 2 | 1 |
| 75887 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-04-30 23:59:59.999999999 | val | 0 | 75 | 2 | 30.017906 | 0 | 94.666667 | 0 | 1 | 2 | 1 |
| 77671 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-05-31 23:59:59.999999999 | val | 0 | 75 | 1 | 30.017906 | 0 | 94.666667 | 0 | 1 | 2 | 1 |
| 82326 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-06-30 23:59:59.999999999 | val | 0 | 75 | 1 | 30.017906 | 0 | 94.666667 | 0 | 1 | 2 | 1 |
| 82881 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-07-31 23:59:59.999999999 | val | 0 | 75 | 1 | 30.017906 | 0 | 94.666667 | 0 | 1 | 2 | 1 |
| 87023 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-08-31 23:59:59.999999999 | val | 0 | 75 | 1 | 30.017906 | 0 | 94.666667 | 0 | 1 | 1 | 1 |
| 90229 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-09-30 23:59:59.999999999 | val | 0 | 75 | 1 | 30.017906 | 0 | 94.666667 | 0 | 1 | 1 | 1 |
| 93783 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2023-10-31 23:59:59.999999999 | val | 0 | 75 | 1 | 30.017906 | 0 | 94.666667 | 0 | 1 | 1 | 1 |
| | PatientSmokingStatusGuid | PatientGuid | ValidFrom | ValidTo | Description | NISTcode | Description_cleaned |
|---|---|---|---|---|---|---|---|
| 389 | 9d3b93ab-5f1e-4909-853d-9b6d4494ad95 | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2020-01-01 00:00:00 | 2022-01-05 20:45:09 | 0 cigaretttes per day (non-smoker or less than... | 4 | 0 cigarettes per day (non-smoker or less than ... |
| 1811 | a558314f-07fa-4a17-822e-9124eb3d4afd | 39b1ac90-9407-43c4-a850-50b27f86d299 | 2022-01-05 20:45:09 | 2099-12-31 23:59:59 | 1-2 packs per day | 1 | 1-2 packs per day |
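The validity-period join can be reproduced on a toy version of the patient shown above, whose status changes at 2022-01-05 20:45:09. This is a minimal sketch rather than the notebook's function: the patient id `p1`, the helper `to_ns`, the set `NON_SMOKER` and the chosen as-at dates are all assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def to_ns(values):
    """Parse timestamps and force a common datetime64[ns] dtype for merge_asof."""
    return pd.to_datetime(values).astype("datetime64[ns]")


NON_SMOKER = {
    "0 cigarettes per day (previous smoker)",
    "0 cigarettes per day (non-smoker or less than 100 in lifetime)",
}

# Two consecutive validity periods, treated as [ValidFrom, ValidTo)
status = pd.DataFrame({
    "PatientGuid": ["p1", "p1"],
    "ValidFrom": to_ns(["2020-01-01 00:00:00", "2022-01-05 20:45:09"]),
    "ValidTo": to_ns(["2022-01-05 20:45:09", "2099-12-31 23:59:59"]),
    "Description_cleaned": [
        "0 cigarettes per day (non-smoker or less than 100 in lifetime)",
        "1-2 packs per day",
    ],
})

# Snapshots: before any record, inside period 1, exactly on the boundary, inside period 2
snapshots = pd.DataFrame({
    "PatientGuid": ["p1"] * 4,
    "AsAtDate": to_ns([
        "2019-12-31 23:59:59",
        "2021-12-31 23:59:59",
        "2022-01-05 20:45:09",
        "2022-01-31 23:59:59",
    ]),
})

# Backward as-of join: latest status row with ValidFrom <= AsAtDate per patient
merged = pd.merge_asof(
    snapshots.sort_values("AsAtDate"),
    status.sort_values("ValidFrom"),
    left_on="AsAtDate",
    right_on="ValidFrom",
    by="PatientGuid",
)

# Keep the match only if AsAtDate falls inside [ValidFrom, ValidTo)
in_window = (merged["AsAtDate"] >= merged["ValidFrom"]) & (merged["AsAtDate"] < merged["ValidTo"])
desc = merged["Description_cleaned"].where(in_window)

# 0 = non-smoker, 1 = smoker, missing where no valid status exists
flag = np.where(desc.isin(NON_SMOKER), 0.0, np.where(desc.notna(), 1.0, np.nan))
merged["SmokerStatus"] = pd.Series(flag, index=merged.index).astype("Int8")
print(merged[["AsAtDate", "SmokerStatus"]])
```

A snapshot that falls exactly on the changeover timestamp takes the later status, because the earlier period is right-open and the later one is left-closed.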
Creation of Population Density
# Q2e Step 1
statedetails_df_cleaned = statedetails_df.copy()
# Calculate population density assuming area is in square miles
statedetails_df_cleaned['PopulationDensity'] = statedetails_df_cleaned['TotalPopulation'] / statedetails_df_cleaned['Area']
# Display the first few rows to verify the new column
display(statedetails_df_cleaned.head())
| | StateGuid | StateCode | StateName | CentroidLatitude | CentroidLongitude | Area | CensusRegion | HospitalCount | HospitalBedCount | BelowPovertyLevel | Aged65Plus | TotalPopulation | ValidFrom | ValidTo | PopulationDensity |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 417f1f50-56a8-46f2-a092-985c489e28c5 | AK | Alaska | 61.3850 | -152.2683 | 665384 | Pacific | 11 | 1288 | 0.1034 | 0.137 | 719445 | 2010-01-01 | 2099-12-31 23:59:59 | 1.081248 |
| 1 | 10544dfe-458b-44b9-90ad-8ca15c34ea91 | AL | Alabama | 32.7990 | -86.8073 | 52420 | East South Central | 90 | 14994 | 0.1598 | 0.186 | 4771614 | 2010-01-01 | 2099-12-31 23:59:59 | 91.026593 |
| 2 | 8a7b3c02-7c90-45cf-b724-489a915d878e | AR | Arkansas | 34.9513 | -92.3809 | 53178 | West South Central | 52 | 8069 | 0.1608 | 0.180 | 2923585 | 2010-01-01 | 2099-12-31 23:59:59 | 54.977340 |
| 3 | 83d1b0e0-207f-4b95-91d0-5f35cae49505 | AS | American Samoa | 14.2417 | -170.7197 | 581 | Polynesia | 1 | 131 | 0.6500 | 0.052 | 55103 | 2010-01-01 | 2099-12-31 23:59:59 | 94.841652 |
| 4 | 4d0241cb-8708-416f-a61e-9e3ffc2249f5 | AZ | Arizona | 33.7712 | -111.3877 | 113990 | Mountain | 76 | 13866 | 0.1412 | 0.190 | 7012999 | 2010-01-01 | 2099-12-31 23:59:59 | 61.522932 |
Creation of Hospital Counts and Bed Counts per 100k
# Q2e Step 1
# Calculate hospital count per 100k of population
statedetails_df_cleaned['HospitalCountPer100k'] = statedetails_df_cleaned['HospitalCount'] / statedetails_df_cleaned['TotalPopulation'] * 100000
# Calculate bed count per 100k of population
statedetails_df_cleaned['BedCountPer100k'] = statedetails_df_cleaned['HospitalBedCount'] / statedetails_df_cleaned['TotalPopulation'] * 100000
# Display the first few rows to verify the new columns
display(statedetails_df_cleaned.head())
# Get descriptive statistics of hospital count per 100k and bed count per 100k
display(statedetails_df_cleaned[['HospitalCountPer100k', 'BedCountPer100k']].describe())
| | StateGuid | StateCode | StateName | CentroidLatitude | CentroidLongitude | Area | CensusRegion | HospitalCount | HospitalBedCount | BelowPovertyLevel | Aged65Plus | TotalPopulation | ValidFrom | ValidTo | PopulationDensity | HospitalCountPer100k | BedCountPer100k |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 417f1f50-56a8-46f2-a092-985c489e28c5 | AK | Alaska | 61.3850 | -152.2683 | 665384 | Pacific | 11 | 1288 | 0.1034 | 0.137 | 719445 | 2010-01-01 | 2099-12-31 23:59:59 | 1.081248 | 1.528956 | 179.026889 |
| 1 | 10544dfe-458b-44b9-90ad-8ca15c34ea91 | AL | Alabama | 32.7990 | -86.8073 | 52420 | East South Central | 90 | 14994 | 0.1598 | 0.186 | 4771614 | 2010-01-01 | 2099-12-31 23:59:59 | 91.026593 | 1.886154 | 314.233297 |
| 2 | 8a7b3c02-7c90-45cf-b724-489a915d878e | AR | Arkansas | 34.9513 | -92.3809 | 53178 | West South Central | 52 | 8069 | 0.1608 | 0.180 | 2923585 | 2010-01-01 | 2099-12-31 23:59:59 | 54.977340 | 1.778638 | 275.996764 |
| 3 | 83d1b0e0-207f-4b95-91d0-5f35cae49505 | AS | American Samoa | 14.2417 | -170.7197 | 581 | Polynesia | 1 | 131 | 0.6500 | 0.052 | 55103 | 2010-01-01 | 2099-12-31 23:59:59 | 94.841652 | 1.814783 | 237.736602 |
| 4 | 4d0241cb-8708-416f-a61e-9e3ffc2249f5 | AZ | Arizona | 33.7712 | -111.3877 | 113990 | Mountain | 76 | 13866 | 0.1412 | 0.190 | 7012999 | 2010-01-01 | 2099-12-31 23:59:59 | 61.522932 | 1.083702 | 197.718551 |
| | HospitalCountPer100k | BedCountPer100k |
|---|---|---|
| count | 56.000000 | 56.000000 |
| mean | 1.408699 | 232.014985 |
| std | 0.469770 | 52.206488 |
| min | 0.813843 | 127.784493 |
| 25% | 1.056866 | 197.220044 |
| 50% | 1.310573 | 226.321931 |
| 75% | 1.698221 | 270.371955 |
| max | 2.823828 | 338.111335 |
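As a worked check of the three derived state-level features, the sketch below recomputes them for the Alaska and Alabama rows displayed above. The helper `add_state_rates` is hypothetical, and the guard that turns a non-positive denominator into a missing value is an added assumption, not something the notebook does.

```python
import numpy as np
import pandas as pd

# Raw inputs copied from the first two rows of the displayed state table
states = pd.DataFrame({
    "StateCode": ["AK", "AL"],
    "Area": [665384, 52420],
    "HospitalCount": [11, 90],
    "HospitalBedCount": [1288, 14994],
    "TotalPopulation": [719445, 4771614],
})


def add_state_rates(df: pd.DataFrame, per: int = 100_000) -> pd.DataFrame:
    out = df.copy()
    # Assumption: treat non-positive denominators as missing rather than dividing by zero
    population = out["TotalPopulation"].where(out["TotalPopulation"] > 0, np.nan)
    area = out["Area"].where(out["Area"] > 0, np.nan)
    # People per unit of area (square miles, as assumed in the notebook)
    out["PopulationDensity"] = population / area
    # Facilities and beds per 100k residents
    out["HospitalCountPer100k"] = out["HospitalCount"] / population * per
    out["BedCountPer100k"] = out["HospitalBedCount"] / population * per
    return out


rates = add_state_rates(states)
print(rates[["StateCode", "PopulationDensity", "HospitalCountPer100k", "BedCountPer100k"]])
```

For Alaska this gives 11 / 719445 × 100000 ≈ 1.53 hospitals per 100k, which agrees with the value in the table above.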
Creation of Number of Prescriptions and Median Number of Refills (12-Month History)
# Q2e Step 1
def add_prescription_12m_features(uoa: pd.DataFrame, prescription: pd.DataFrame, empty_median_fill: float = 0.0) -> pd.DataFrame:
    """
    For each (PatientGuid, AsAtDate) in unit_of_analysis_df, compute over the lookback window (AsAtDate - 365 days, AsAtDate]:
    - NumberOfPrescriptions_p12m: number of prescription rows
    - MedianNumberOfRefills_p12m: median of 'NumberOfRefills' (empty_median_fill if no rows)
    """
    # Sort for per-patient binary searches
    rx = prescription.sort_values(['PatientGuid', 'Timestamp']).reset_index(drop = True)
    uoa = uoa.sort_values(['PatientGuid', 'AsAtDate']).reset_index(drop = True)
    # Prepare result arrays aligned to the sorted uoa rows
    num_rx = np.zeros(len(uoa), dtype = np.int32)
    med_ref = np.full(len(uoa), empty_median_fill, dtype = float)
    # Group both frames by patient for efficient windowed counts
    rx_grp = {pid: g for pid, g in rx.groupby('PatientGuid', sort = False)}
    uoa_grp = uoa.groupby('PatientGuid', sort = False)
    # Loop over patients (fast enough because snapshots are monthly)
    for pid, u in uoa_grp:
        if pid not in rx_grp:
            # No prescriptions for this patient: keep 0 and empty_median_fill
            continue
        g = rx_grp[pid]
        times = g['Timestamp'].to_numpy()
        refills = g['NumberOfRefills'].to_numpy()
        # For each AsAtDate, find indices of events in (S - 365, S]
        idxs = u.index.to_numpy()
        asofs = u['AsAtDate'].to_numpy()
        for i, S in zip(idxs, asofs):
            if pd.isna(S):
                num_rx[i] = 0
                med_ref[i] = empty_median_fill
                continue
            start = S - np.timedelta64(365, 'D')
            # Left bound: first event strictly > start (exclude exactly S - 365)
            L = np.searchsorted(times, start, side = 'right')
            # Right bound: one past the last event <= S (inclusive of S)
            R = np.searchsorted(times, S, side = 'right')
            if L >= R:
                num_rx[i] = 0
                med_ref[i] = empty_median_fill
            else:
                # Get number of prescriptions
                num_rx[i] = R - L
                # Get median number of refills
                med_ref[i] = float(np.median(refills[L:R]))
    # Attach results; the arrays are aligned to the RangeIndex of the sorted uoa
    out = uoa.copy()
    out['NumberOfPrescriptions_p12m'] = num_rx
    out['MedianNumberOfRefills_p12m'] = med_ref
    # Note: rows are returned sorted by (PatientGuid, AsAtDate) with a fresh
    # RangeIndex; the caller's original row order is not restored
    return out
# Attach columns to unit of analysis
unit_of_analysis_df_v2 = add_prescription_12m_features(unit_of_analysis_df_v2, prescription_df_cleaned_v2)
# Confirm that the number of rows has not changed (the shape after the insert
# is displayed first, followed by the original shape)
print("\nShape of unit_of_analysis_df before and after inserting number of prescriptions and median number of refills:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Sample check a patient at a particular as-at-date to verify that the number of
# prescriptions and median number of refills are correctly calculated
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2021-09-30').date())])
display(prescription_df_cleaned_v2[(prescription_df_cleaned_v2['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (prescription_df_cleaned_v2['Timestamp'].dt.date <= pd.to_datetime('2021-09-30').date())].sort_values(by = 'Timestamp'))
Shape of unit_of_analysis_df before and after inserting number of prescriptions and median number of refills:
(98704, 15)
(98704, 4)
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | NumberOfPrescriptions_p12m | MedianNumberOfRefills_p12m |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 8 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-09-30 23:59:59.999999999 | test | 0 | 27 | 6 | 29.508642 | 1 | 102.333333 | 0 | 1 | 0 | 0 | 4 | 1.5 |
| | PrescriptionGuid | PatientGuid | Timestamp | tz_offset | Quantity | NumberOfRefills | RefillAsNeeded | GenericAllowed | NdcCode | MedicationName | MedicationStrength | Schedule |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 112 | efd3bce2-dccf-43c2-9530-407f38317fc2 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-05-11 19:46:41 | -04:00 | 1 | 0 | False | True | 0008-1222 | Pristiq (desvenlafaxine) oral tablet, extended... | 100 mg | NaN |
| 113 | 7cad124f-4250-48ca-b06b-79b5655fa325 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2020-09-23 19:46:21 | -04:00 | 30 | 0 | False | True | 0002-3270 | Cymbalta (DULoxetine) oral delayed release cap... | 60 mg | NaN |
| 114 | c2944a6b-e650-4f94-8d33-244c69fc690e | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-02-16 19:12:37 | -04:00 | 90 | 0 | False | True | 0008-1222 | Pristiq (desvenlafaxine) oral tablet, extended... | 100 mg | NaN |
| 115 | 1eb25449-2816-43ea-9fe8-140ba91d9941 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-05-26 18:00:20 | -04:00 | 3 | 3 | False | True | 0002-3270 | Cymbalta (DULoxetine) oral delayed release cap... | 60 mg | NaN |
| 116 | 15f0e571-4352-4ba9-ae84-3055b929e228 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-09-14 13:23:48 | -04:00 | 15 | 3 | False | True | 0004-1964 | Rocephin (cefTRIAXone) injectable powder for i... | 1 g | NaN |
| 117 | 69183936-f2af-4387-86c5-f20eb1b1a4d1 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2021-09-16 17:32:20 | -04:00 | 1 | 0 | False | True | 0004-1964 | Rocephin (cefTRIAXone) injectable powder for i... | 1 g | 2.0 |
Q2e Step 2 - Data Consolidation (i.e. Joining the Tables and Keeping Relevant Features)¶
From 'Patient' Table¶
# Q2e Step 2
# Count the number of distinct Gender/StateCode/BloodType values recorded for each PatientGuid
gender_counts = patient_df_cleaned.groupby('PatientGuid')['Gender'].nunique().reset_index(name = 'count')
statecode_counts = patient_df_cleaned.groupby('PatientGuid')['StateCode'].nunique().reset_index(name = 'count')
bloodtype_counts = patient_df_cleaned.groupby('PatientGuid')['BloodType'].nunique().reset_index(name = 'count')
# Display PatientGuids where the count is greater than 1
multiple_gender = gender_counts[gender_counts['count'] > 1]
multiple_statecode = statecode_counts[statecode_counts['count'] > 1]
multiple_bloodtype = bloodtype_counts[bloodtype_counts['count'] > 1]
if not multiple_gender.empty:
    print("PatientGuid and Gender combinations appearing more than once:")
    display(multiple_gender)
else:
    print("No duplicate PatientGuid and Gender combinations found.")
if not multiple_statecode.empty:
    print("PatientGuid and StateCode combinations appearing more than once:")
    display(multiple_statecode)
else:
    print("No duplicate PatientGuid and StateCode combinations found.")
if not multiple_bloodtype.empty:
    print("PatientGuid and BloodType combinations appearing more than once:")
    display(multiple_bloodtype)
else:
    print("No duplicate PatientGuid and BloodType combinations found.")
# Attach state code information from patient_df_cleaned to
# unit_of_analysis_df_v2
# Validity periods must be respected because a patient's state code can change over time
# Treat validity periods as left-closed, right-open intervals "[)"
def attach_statecodes(uoa: pd.DataFrame, patient: pd.DataFrame) -> pd.DataFrame:
    """
    Attach StateCode to each row in unit_of_analysis_df while respecting validity periods [ValidFrom, ValidTo) (i.e. left-closed, right-open).
    """
    # For each uoa row, join the last patient row with ValidFrom <= AsAtDate within the same PatientGuid
    merged = pd.merge_asof(uoa.sort_values('AsAtDate'), patient.sort_values('ValidFrom')[['PatientGuid', 'ValidFrom', 'ValidTo', 'StateCode']], left_on = 'AsAtDate', right_on = 'ValidFrom', by = 'PatientGuid')
    # Blank out state codes whose validity period does not contain AsAtDate
    in_window = (merged['AsAtDate'] >= merged['ValidFrom']) & (merged['AsAtDate'] < merged['ValidTo'])
    merged.loc[~in_window, 'StateCode'] = pd.NA
    # Keep rows ordered by the merged index (the merge_asof output is ordered by AsAtDate)
    merged = merged.sort_index()
    # Drop columns "ValidFrom" and "ValidTo"
    merged = merged.drop(columns = ["ValidFrom", "ValidTo"])
    return merged
# Attach state codes to unit of analysis
unit_of_analysis_df_v2 = attach_statecodes(unit_of_analysis_df_v2, patient_df_cleaned)
# Get gender and blood type into unit of analysis
# No need to use validity periods since each patient has only one gender and
# blood type, and we do not expect this information to vary over time
unit_of_analysis_df_v2 = pd.merge(unit_of_analysis_df_v2, patient_df_cleaned[['PatientGuid', 'Gender', 'BloodType']].drop_duplicates(), on = 'PatientGuid', how = 'left')
# Convert gender to F: 0 and M: 1
unit_of_analysis_df_v2['Gender'] = unit_of_analysis_df_v2['Gender'].map({'M': 1, 'F': 0}).astype(int)
# Confirm that the number of rows in the resulting dataframe has not changed
print("\nShape of unit_of_analysis_df before and after inclusion of state code, gender, and blood type information:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Display rows where state code or gender or blood type is empty
print("\nRows in unit_of_analysis_df_v2 where statecode is empty:")
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['StateCode'].isna()])
print(f"\nRows in unit_of_analysis_df_v2 where gender is empty: {len(unit_of_analysis_df_v2[unit_of_analysis_df_v2['Gender'].isna()])}")
print(f"\nRows in unit_of_analysis_df_v2 where blood type is empty: {len(unit_of_analysis_df_v2[unit_of_analysis_df_v2['BloodType'].isna()])}")
# Retrieve relevant rows from patient_df_cleaned
display(patient_df_cleaned[(patient_df_cleaned['PatientGuid'] == '618ad33e-b907-4cb6-a0a9-4889ca284cf1') | (patient_df_cleaned['PatientGuid'] == 'bd9a05c0-a603-4e20-8131-a3a5dff38653')])
# Four rows (corresponding to two patients) have missing state code information
# as their as-at-dates are before their corresponding "valid from" dates in the
# 'patient' table. However, they only appear once each in the 'patient' table,
# so we assume that their state code has not changed and, hence, assume that the
# "valid from" dates for these two patients are erroneous. We bypass this by
# manually adding their state codes for the relevant rows.
# Set statecode = 'MN' for patientguid = '618ad33e-b907-4cb6-a0a9-4889ca284cf1'
unit_of_analysis_df_v2.loc[unit_of_analysis_df_v2['PatientGuid'] == '618ad33e-b907-4cb6-a0a9-4889ca284cf1', 'StateCode'] = 'MN'
# Set statecode = 'PA' for patientguid = 'bd9a05c0-a603-4e20-8131-a3a5dff38653'
unit_of_analysis_df_v2.loc[unit_of_analysis_df_v2['PatientGuid'] == 'bd9a05c0-a603-4e20-8131-a3a5dff38653', 'StateCode'] = 'PA'
# Check for any more rows with missing state code
print("\nRows in unit_of_analysis_df_v2 where statecode is empty:")
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['StateCode'].isna()])
# Sample check a patient with changing state codes to verify that the change is
# consistent with the information in the 'patient' table
# Use same patient to verify that the gender and blood type are consistent with
# the information in the 'patient' table
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['PatientGuid'] == '018742a7-711a-471a-bf10-49bad6c163e0'])
display(patient_df_cleaned[patient_df_cleaned['PatientGuid'] == '018742a7-711a-471a-bf10-49bad6c163e0'])
No duplicate PatientGuid and Gender combinations found.
PatientGuid and StateCode combinations appearing more than once:
| | PatientGuid | count |
|---|---|---|
| 4 | 00fc0c07-ebea-4f03-b81f-76a10079b603 | 2 |
| 8 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2 |
| 16 | 0234b7a6-f11d-4e63-9c9f-2e21e00379ae | 2 |
| 19 | 0276f514-f76e-421f-b6a1-6520d319aae7 | 2 |
| 27 | 03116aa8-fcb5-407e-bfa6-92594d3ab5b4 | 2 |
| ... | ... | ... |
| 3076 | fbdeebd0-e7b7-4e64-aeb3-94319273457c | 2 |
| 3080 | fc2085e6-42a3-49aa-a8ed-1e80e4c266c5 | 2 |
| 3083 | fc818408-e9cf-49b6-8220-89f72753db74 | 2 |
| 3106 | fedfacd0-d98e-482e-b18e-cd044bfa41a5 | 3 |
| 3116 | ff98f430-b82b-4f28-a006-ee8214dba787 | 2 |
455 rows × 2 columns
No duplicate PatientGuid and BloodType combinations found.
Shape of unit_of_analysis_df before and after inclusion of state code, gender, and blood type information:
(98704, 18)
(98704, 4)
Rows in unit_of_analysis_df_v2 where statecode is empty:
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | NumberOfPrescriptions_p12m | MedianNumberOfRefills_p12m | StateCode | Gender | BloodType |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 34631 | 618ad33e-b907-4cb6-a0a9-4889ca284cf1 | 2022-03-31 23:59:59.999999999 | val | 1 | 51 | 0 | 31.129490 | 1 | 91.666667 | 1 | 0 | 0 | 0 | 0 | 0.0 | <NA> | 0 | B+ |
| 37450 | bd9a05c0-a603-4e20-8131-a3a5dff38653 | 2022-04-30 23:59:59.999999999 | val | 0 | 57 | 0 | 31.975702 | 1 | 99.333333 | 1 | 0 | 0 | 0 | 0 | 0.0 | <NA> | 0 | O+ |
| 39192 | 618ad33e-b907-4cb6-a0a9-4889ca284cf1 | 2022-04-30 23:59:59.999999999 | val | 1 | 51 | 0 | 31.129490 | 1 | 91.666667 | 1 | 0 | 0 | 0 | 0 | 0.0 | <NA> | 0 | B+ |
| 42491 | bd9a05c0-a603-4e20-8131-a3a5dff38653 | 2022-05-31 23:59:59.999999999 | val | 0 | 57 | 0 | 31.975702 | 1 | 99.333333 | 1 | 0 | 0 | 0 | 0 | 0.0 | <NA> | 0 | O+ |
Rows in unit_of_analysis_df_v2 where gender is empty: 0
Rows in unit_of_analysis_df_v2 where blood type is empty: 0
| | RowID | ValidFrom | ValidTo | PatientGuid | Gender | DateOfBirth | StateCode | BloodType |
|---|---|---|---|---|---|---|---|---|
| 3291 | 64d0d2ac-9586-40cf-b858-e755a7c4714b | 2022-05-22 | 2099-12-31 23:59:59 | 618ad33e-b907-4cb6-a0a9-4889ca284cf1 | F | 1970-07-17 | MN | B+ |
| 3305 | 31cfefaa-41a5-4f0e-9fc8-b3cb74bbd019 | 2022-06-23 | 2099-12-31 23:59:59 | bd9a05c0-a603-4e20-8131-a3a5dff38653 | F | 1964-11-11 | PA | O+ |
Rows in unit_of_analysis_df_v2 where statecode is empty:
| PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | NumberOfPrescriptions_p12m | MedianNumberOfRefills_p12m | StateCode | Gender | BloodType |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | NumberOfPrescriptions_p12m | MedianNumberOfRefills_p12m | StateCode | Gender | BloodType |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 14416 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2021-08-31 23:59:59.999999999 | train | 0 | 86 | 3 | 29.118343 | 1 | 93.333333 | 1 | 1 | 0 | 0 | 1 | 0.0 | FL | 0 | A+ |
| 19136 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2021-09-30 23:59:59.999999999 | train | 0 | 86 | 3 | 29.118343 | 1 | 93.333333 | 1 | 1 | 0 | 0 | 0 | 0.0 | FL | 0 | A+ |
| 21315 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2021-10-31 23:59:59.999999999 | train | 0 | 86 | 3 | 29.118343 | 1 | 93.333333 | 1 | 1 | 0 | 0 | 0 | 0.0 | FL | 0 | A+ |
| 24506 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2021-11-30 23:59:59.999999999 | train | 0 | 86 | 3 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 0 | 0.0 | FL | 0 | A+ |
| 25357 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2021-12-31 23:59:59.999999999 | train | 0 | 86 | 3 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 0 | 0.0 | FL | 0 | A+ |
| 30110 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-01-31 23:59:59.999999999 | train | 0 | 86 | 3 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 0 | 0.0 | FL | 0 | A+ |
| 33002 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-02-28 23:59:59.999999999 | train | 0 | 86 | 3 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 0 | 0.0 | FL | 0 | A+ |
| 34720 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-03-31 23:59:59.999999999 | train | 0 | 86 | 3 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 37047 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-04-30 23:59:59.999999999 | train | 0 | 86 | 2 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 42264 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-05-31 23:59:59.999999999 | train | 0 | 86 | 2 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 42856 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-06-30 23:59:59.999999999 | train | 0 | 86 | 1 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 47433 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-07-31 23:59:59.999999999 | train | 0 | 87 | 1 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 51605 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-08-31 23:59:59.999999999 | train | 0 | 87 | 1 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 52560 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-09-30 23:59:59.999999999 | train | 0 | 87 | 0 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 56028 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-10-31 23:59:59.999999999 | train | 0 | 87 | 0 | 40.473399 | 0 | 100.666667 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 58896 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-11-30 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 63122 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2022-12-31 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 64852 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-01-31 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 69222 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-02-28 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 1 | 0 | 0 | 1 | 1.0 | AK | 0 | A+ |
| 73303 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-03-31 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 75014 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-04-30 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 77586 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-05-31 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 80511 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-06-30 23:59:59.999999999 | train | 0 | 87 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 85101 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-07-31 23:59:59.999999999 | train | 0 | 88 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 86478 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-08-31 23:59:59.999999999 | train | 0 | 88 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 89369 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-09-30 23:59:59.999999999 | train | 1 | 88 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 93596 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-10-31 23:59:59.999999999 | train | 1 | 88 | 0 | 41.620418 | 0 | 100.000000 | 0 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| 95094 | 018742a7-711a-471a-bf10-49bad6c163e0 | 2023-11-30 23:59:59.999999999 | train | 1 | 88 | 0 | 29.118343 | 1 | 93.333333 | 1 | 2 | 0 | 0 | 0 | 0.0 | AK | 0 | A+ |
| | RowID | ValidFrom | ValidTo | PatientGuid | Gender | DateOfBirth | StateCode | BloodType |
|---|---|---|---|---|---|---|---|---|
| 8 | 4c7e6947-10d4-451f-97dc-be438a90fe8b | 2020-01-01 | 2022-03-08 00:00:00 | 018742a7-711a-471a-bf10-49bad6c163e0 | F | 1935-07-23 | FL | A+ |
| 3264 | 330925e8-5ed0-48c0-9174-02d76cfecc80 | 2022-03-08 | 2099-12-31 23:59:59 | 018742a7-711a-471a-bf10-49bad6c163e0 | F | 1935-07-23 | AK | A+ |
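The validity-period join used by attach_statecodes can be illustrated on a small example. The sketch below uses hypothetical toy data (a single patient "p1" who moves from FL to AK, loosely mirroring the sample check above) and is not the notebook's actual data. It assumes the same approach: a backward as-of join on ValidFrom, followed by a check that AsAtDate falls in [ValidFrom, ValidTo).

```python
import pandas as pd

# Hypothetical toy data: one patient who moves from FL to AK on 2022-03-08.
patient = pd.DataFrame({
    "PatientGuid": ["p1", "p1"],
    "ValidFrom": pd.to_datetime(["2020-01-01", "2022-03-08"]),
    "ValidTo": pd.to_datetime(["2022-03-08", "2099-12-31"]),
    "StateCode": ["FL", "AK"],
})
snapshots = pd.DataFrame({
    "PatientGuid": ["p1", "p1", "p1"],
    "AsAtDate": pd.to_datetime(["2019-12-31", "2022-02-28", "2022-03-31"]),
})

# Backward as-of join: for each snapshot, take the last patient row
# with ValidFrom <= AsAtDate for the same PatientGuid.
merged = pd.merge_asof(
    snapshots.sort_values("AsAtDate"),
    patient.sort_values("ValidFrom"),
    left_on="AsAtDate",
    right_on="ValidFrom",
    by="PatientGuid",
)

# Enforce the left-closed, right-open window [ValidFrom, ValidTo).
in_window = (merged["AsAtDate"] >= merged["ValidFrom"]) & (merged["AsAtDate"] < merged["ValidTo"])
merged.loc[~in_window, "StateCode"] = pd.NA

state_codes = merged["StateCode"].tolist()
print(state_codes)
```

The first snapshot predates every validity period, so its state code is left missing. This is the same situation as the four rows that were patched manually above.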
From 'Pathology' Table¶
# Q2e Step 2
def add_most_recent_pathology_cluster(uoa: pd.DataFrame, labresults: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
    """
    For each (PatientGuid, AsAtDate) in unit_of_analysis_df, compute:
    - MostRecentPathologyReportCluster: cluster from the most recent pathology report not exceeding AsAtDate
    """
    # Join labresults to get 'PatientGuid' and 'Timestamp'
    reports = reports.merge(labresults[['LabResultGuid', 'PatientGuid', 'Timestamp']], on = 'LabResultGuid', how = 'left')
    # Sort for per-patient binary searches
    reports = reports.sort_values(['PatientGuid', 'Timestamp']).reset_index(drop = True)
    uoa = uoa.sort_values(['PatientGuid', 'AsAtDate']).reset_index(drop = True)
    # Prepare result array aligned to uoa rows
    most_recent_clusters = [None] * len(uoa)
    # Group both frames by patient for efficient per-patient lookups
    reports_grp = {pid: g for pid, g in reports.groupby('PatientGuid', sort = False)}
    uoa_grp = uoa.groupby('PatientGuid', sort = False)
    # Loop patients (fast enough because snapshots are monthly)
    for pid, u in uoa_grp:
        if pid not in reports_grp:
            continue
        g = reports_grp[pid]
        times = g['Timestamp'].to_numpy()
        clusters = g['cluster'].to_numpy()
        # For each AsAtDate S, find the most recent report with Timestamp <= S
        idxs = u.index.to_numpy()
        asofs = u['AsAtDate'].to_numpy()
        for i, S in zip(idxs, asofs):
            if pd.isna(S):
                continue
            # Right bound: last event <= S (inclusive)
            R = np.searchsorted(times, S, side = 'right') - 1
            if R >= 0:
                # Get most recent cluster
                most_recent_clusters[i] = int(clusters[R]) if pd.notna(clusters[R]) else None
    # Attach results to the sorted frame
    # The index from the sorted uoa is used to align the results
    out = uoa.copy()
    out['MostRecentPathologyReportCluster'] = most_recent_clusters
    # Rows stay in (PatientGuid, AsAtDate) order: uoa was re-indexed after
    # sorting, so this step leaves the order unchanged
    original_order_uoa = uoa.sort_index().copy()
    out = out.set_index(original_order_uoa.index)
    return out
# Attach cluster information to unit of analysis
unit_of_analysis_df_v2 = add_most_recent_pathology_cluster(unit_of_analysis_df_v2, labresult_df, pathology_df_cleaned)
# Convert column to nullable int data type
unit_of_analysis_df_v2['MostRecentPathologyReportCluster'] = unit_of_analysis_df_v2['MostRecentPathologyReportCluster'].astype('Int64')
# Confirm that the number of rows in the resulting dataframe has not changed
print("\nShape of unit_of_analysis_df before and after inclusion of clusters:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Display rows where cluster is missing
print(f"\nRows in unit_of_analysis_df_v2 where cluster is missing: {len(unit_of_analysis_df_v2[unit_of_analysis_df_v2['MostRecentPathologyReportCluster'].isna()])}")
# Sample check a patient at a particular as-at-date to verify that the cluster
# was correctly pulled
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2022-04-30').date())])
display(labresult_df[(labresult_df['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (labresult_df['Timestamp'].dt.date <= pd.to_datetime('2022-04-30').date())].sort_values(by = 'Timestamp'))
display(pathology_df_cleaned[pathology_df_cleaned['LabResultGuid'] == '7ee243d0-a158-46fa-8d0f-1ca401de2b77']['cluster'])
# Sample check a patient at a particular as-at-date to check on appropriateness
# of missing cluster
display(unit_of_analysis_df_v2[(unit_of_analysis_df_v2['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (unit_of_analysis_df_v2['AsAtDate'].dt.date == pd.to_datetime('2022-03-31').date())])
display(labresult_df[(labresult_df['PatientGuid'] == '00548e9d-e222-4498-accc-ea18f0fc0f49') & (labresult_df['Timestamp'].dt.date <= pd.to_datetime('2022-03-31').date())].sort_values(by = 'Timestamp'))
# Impute rows with -1 to signify no cluster
unit_of_analysis_df_v2['MostRecentPathologyReportCluster'] = unit_of_analysis_df_v2['MostRecentPathologyReportCluster'].fillna(-1).astype(int)
# Verify that no more rows have missing cluster
print(f"\nRows in unit_of_analysis_df_v2 where cluster is missing: {len(unit_of_analysis_df_v2[unit_of_analysis_df_v2['MostRecentPathologyReportCluster'].isna()])}")
Shape of unit_of_analysis_df before and after inclusion of clusters:
(98704, 19)
(98704, 4)
Rows in unit_of_analysis_df_v2 where cluster is missing: 35517
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | NumberOfPrescriptions_p12m | MedianNumberOfRefills_p12m | StateCode | Gender | BloodType | MostRecentPathologyReportCluster |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 15 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2022-04-30 23:59:59.999999999 | test | 0 | 28 | 6 | 29.508642 | 1 | 102.166667 | 0 | 1 | 1 | 0 | 7 | 3.0 | MI | 1 | A+ | 1 |
| | LabResultGuid | PatientGuid | Timestamp | tz_offset |
|---|---|---|---|---|
| 4995 | 7ee243d0-a158-46fa-8d0f-1ca401de2b77 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2022-04-13 22:01:03 | -04:00 |
| | cluster |
|---|---|
| 0 | 1 |
| | PatientGuid | AsAtDate | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | NumberOfPrescriptions_p12m | MedianNumberOfRefills_p12m | StateCode | Gender | BloodType | MostRecentPathologyReportCluster |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 14 | 00548e9d-e222-4498-accc-ea18f0fc0f49 | 2022-03-31 23:59:59.999999999 | test | 0 | 28 | 5 | 29.508642 | 1 | 102.166667 | 0 | 1 | 0 | 0 | 7 | 3.0 | MI | 1 | A+ | <NA> |
| LabResultGuid | PatientGuid | Timestamp | tz_offset |
|---|---|---|---|
Rows in unit_of_analysis_df_v2 where cluster is missing: 0
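The "most recent report at or before the as-at date" lookup relies on a binary search over sorted timestamps. A minimal sketch of that step is shown below, assuming hypothetical timestamps and cluster labels for a single patient (the helper name most_recent_cluster is illustrative and does not appear in the notebook).

```python
import numpy as np

# Hypothetical report timestamps for one patient (sorted ascending) and their clusters.
times = np.array(["2022-01-10", "2022-04-13", "2022-06-01"], dtype="datetime64[ns]")
clusters = np.array([2, 1, 0])

def most_recent_cluster(as_at):
    """Return the cluster of the latest report with timestamp <= as_at, or None."""
    # side='right' returns the insertion point after any equal timestamps,
    # so subtracting 1 lands on the last event at or before as_at.
    pos = np.searchsorted(times, np.datetime64(as_at, "ns"), side="right") - 1
    return int(clusters[pos]) if pos >= 0 else None

results = [most_recent_cluster(d) for d in ["2021-12-31", "2022-04-13", "2022-04-30", "2023-01-01"]]
print(results)  # [None, 1, 1, 0]
```

An as-at date earlier than every report yields None, which corresponds to the rows that are later imputed with -1 in the notebook.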
From 'StateDetails' Table¶
# Q2e Step 2
# Each observation in statedetails_df has only one validity period, and all
# validity periods cover our dataset period
def attach_state_attributes(uoa: pd.DataFrame, statedetails: pd.DataFrame) -> pd.DataFrame:
    """
    Add state-level columns from statedetails_df_cleaned to unit_of_analysis_df.
    """
    # Work on a copy so the helper column is not added to the caller's dataframe
    uoa = uoa.copy()
    # Record the original row order
    uoa['_row_ix'] = range(len(uoa))
    # Left join on StateCode
    merged = uoa.merge(statedetails[['StateCode', 'PopulationDensity', 'HospitalCountPer100k', 'BedCountPer100k', 'BelowPovertyLevel', 'Aged65Plus']], on = 'StateCode', how = 'left').sort_values('_row_ix')
    # Drop helper columns
    merged = merged.drop(columns = ['_row_ix'])
    return merged
# Attach state-level information to unit of analysis
unit_of_analysis_df_v2 = attach_state_attributes(unit_of_analysis_df_v2, statedetails_df_cleaned)
# Confirm that the number of rows in the resulting dataframe has not changed
print("\nShape of unit_of_analysis_df before and after inclusion of state-level attributes:")
display(unit_of_analysis_df_v2.shape)
display(unit_of_analysis_df.shape)
# Display rows where state-level attributes are empty
print(f"\nRows in unit_of_analysis_df_v2 where state-level attributes are empty: {len(unit_of_analysis_df_v2[unit_of_analysis_df_v2['PopulationDensity'].isna() | unit_of_analysis_df_v2['HospitalCountPer100k'].isna() | unit_of_analysis_df_v2['BedCountPer100k'].isna() | unit_of_analysis_df_v2['BelowPovertyLevel'].isna() | unit_of_analysis_df_v2['Aged65Plus'].isna()])}")
# Sample check a patient to verify that the state-level attributes are correct
display(unit_of_analysis_df_v2[unit_of_analysis_df_v2['PatientGuid'] == 'ed1ca471-b7d6-4e3a-8e3c-189e25ddc434'][['PatientGuid', 'StateCode', 'PopulationDensity', 'HospitalCountPer100k', 'BedCountPer100k', 'BelowPovertyLevel', 'Aged65Plus']].agg(lambda x: x.unique()))
display(statedetails_df_cleaned[statedetails_df_cleaned['StateCode'] == 'TX'])
Shape of unit_of_analysis_df before and after inclusion of state-level attributes:
(98704, 24)
(98704, 4)
Rows in unit_of_analysis_df_v2 where state-level attributes are empty: 0
| | PatientGuid | StateCode | PopulationDensity | HospitalCountPer100k | BedCountPer100k | BelowPovertyLevel | Aged65Plus |
|---|---|---|---|---|---|---|---|
| 0 | ed1ca471-b7d6-4e3a-8e3c-189e25ddc434 | TX | 104.295842 | 1.317225 | 213.69024 | 0.1422 | 0.139 |
| | StateGuid | StateCode | StateName | CentroidLatitude | CentroidLongitude | Area | CensusRegion | HospitalCount | HospitalBedCount | BelowPovertyLevel | Aged65Plus | TotalPopulation | ValidFrom | ValidTo | PopulationDensity | HospitalCountPer100k | BedCountPer100k |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 47 | d0acb213-3878-4469-ad87-b74e2f7ec7ad | TX | Texas | 31.106 | -97.6475 | 268596 | West South Central | 369 | 59862 | 0.1422 | 0.139 | 28013446 | 2010-01-01 | 2099-12-31 23:59:59 | 104.295842 | 1.317225 | 213.69024 |
Q2e Step 3 - Final Data Preparations¶
Check Missing Data¶
# Q2e Step 3
# Check for rows with any missing data
print(f"Rows in unit_of_analysis_df_v2 with any missing data: {len(unit_of_analysis_df_v2[unit_of_analysis_df_v2.isna().any(axis = 1)])}")
Rows in unit_of_analysis_df_v2 with any missing data: 0
Remove Unnecessary Columns and Duplicates¶
# Q2e Step 3
# Drop 'PatientGuid', 'AsAtDate' and remove duplicates
unit_of_analysis_df_v2 = unit_of_analysis_df_v2.drop(columns = ['PatientGuid', 'AsAtDate']).drop_duplicates()
# Display the head of unit_of_analysis_df_v2 to verify that the columns have been dropped
display(unit_of_analysis_df_v2.head(1))
# Check for any duplicate rows in unit_of_analysis_df_v2
print(f"Duplicate rows in unit_of_analysis_df_v2: {unit_of_analysis_df_v2.duplicated().sum()}")
| | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | ... | MedianNumberOfRefills_p12m | StateCode | Gender | BloodType | MostRecentPathologyReportCluster | PopulationDensity | HospitalCountPer100k | BedCountPer100k | BelowPovertyLevel | Aged65Plus |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | test | 0 | 26 | 2 | 25.426717 | 0 | 102.5 | 0 | 1 | 0 | ... | 0.0 | MI | 1 | A+ | -1 | 100.850361 | 1.045774 | 234.673746 | 0.1371 | 0.187 |
1 rows × 22 columns
Duplicate rows in unit_of_analysis_df_v2: 0
One-Hot Encoding of Categorical Features¶
# Q2e Step 3
# Columns to one-hot encode
columns_to_encode = ['MostRecentPathologyReportCluster', 'StateCode', 'BloodType']
# Perform one-hot encoding
unit_of_analysis_df_v2 = pd.get_dummies(unit_of_analysis_df_v2, columns = columns_to_encode, drop_first = True)
# Convert boolean columns to 1/0
for col in [col for col in unit_of_analysis_df_v2.columns if unit_of_analysis_df_v2[col].dtype == 'bool']:
    unit_of_analysis_df_v2[col] = unit_of_analysis_df_v2[col].astype(int)
# Check data types in unit_of_analysis_df_v2
print(unit_of_analysis_df_v2.dtypes)
# Display the head of unit_of_analysis_df_v2 to verify
display(unit_of_analysis_df_v2.head(1))
set object
AcuteDiagnosis int64
AgeLastBirthday int64
unique_ICD9CodesDesc_p12m int32
MedianBMI_p12m float64
...
BloodType_AB- int64
BloodType_B+ int64
BloodType_B- int64
BloodType_O+ int64
BloodType_O- int64
Length: 82, dtype: object
| | set | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | ... | StateCode_WI | StateCode_WV | StateCode_WY | BloodType_A- | BloodType_AB+ | BloodType_AB- | BloodType_B+ | BloodType_B- | BloodType_O+ | BloodType_O- |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | test | 0 | 26 | 2 | 25.426717 | 0 | 102.5 | 0 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1 rows × 82 columns
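The effect of drop_first = True can be seen on a small frame. The sketch below uses a hypothetical toy dataframe (the column values are made up). The first category in sorted order is omitted and becomes the baseline, which is consistent with the absence of a BloodType_A+ column in the output above.

```python
import pandas as pd

# Hypothetical toy frame with one categorical column.
toy = pd.DataFrame({"BloodType": ["A+", "B+", "O+", "A+"], "Age": [30, 41, 52, 63]})

# drop_first=True omits the first category in sorted order ('A+'), which becomes
# the baseline represented by zeros in all remaining dummy columns.
encoded = pd.get_dummies(toy, columns=["BloodType"], drop_first=True, dtype=int)

print(encoded.columns.tolist())
print(encoded)
```

Rows whose blood type is the baseline category have 0 in every BloodType_* column, so no information is lost by dropping it.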
Separation of Data into Separate Training, Validation, and Test Set DataFrames¶
# Q2e Step 3
# Separate out training, validation, and testing sets, dropping the 'set' column
# Explicit copies are taken so that the scaled values assigned later are written
# to independent dataframes rather than to slices of unit_of_analysis_df_v2
training_df = unit_of_analysis_df_v2[unit_of_analysis_df_v2['set'] == 'train'].drop(columns = ['set']).copy()
val_df = unit_of_analysis_df_v2[unit_of_analysis_df_v2['set'] == 'val'].drop(columns = ['set']).copy()
test_df = unit_of_analysis_df_v2[unit_of_analysis_df_v2['set'] == 'test'].drop(columns = ['set']).copy()
Standardisation of Numeric Features¶
# Q2e Step 3
# List of columns to standardise
columns_to_standardise = [
'AgeLastBirthday',
'unique_ICD9CodesDesc_p12m',
'MedianBMI_p12m',
'MedianMAP_p12m',
'unique_PhysicianSpecialtyGroups_p12m',
'NumberOfAbnormalLabResults_p12m',
'NumberOfPrescriptions_p12m',
'MedianNumberOfRefills_p12m',
'PopulationDensity',
'HospitalCountPer100k',
'BedCountPer100k',
'BelowPovertyLevel',
'Aged65Plus']
# Initialise the StandardScaler and fit only on training data
scaler = StandardScaler()
scaler.fit(training_df[columns_to_standardise])
# Apply standard scaling to the specified columns within each of the training,
# validation, and testing sets
for df in [training_df, val_df, test_df]:
    df[columns_to_standardise] = scaler.transform(df[columns_to_standardise])
# Display the head of the scaled DataFrames to verify reasonableness
display(training_df.head(1))
display(val_df.head(1))
display(test_df.head(1))
| | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | ... | StateCode_WI | StateCode_WV | StateCode_WY | BloodType_A- | BloodType_AB+ | BloodType_AB- | BloodType_B+ | BloodType_B- | BloodType_O+ | BloodType_O- |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 108 | 0 | -1.4249 | -0.925884 | 1.944609 | 0 | 1.13649 | 0 | -0.075505 | 0.204854 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
1 rows × 81 columns
| | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | ... | StateCode_WI | StateCode_WV | StateCode_WY | BloodType_A- | BloodType_AB+ | BloodType_AB- | BloodType_B+ | BloodType_B- | BloodType_O+ | BloodType_O- |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 72 | 1 | 0.772383 | -0.342948 | -0.034804 | 1 | -0.062636 | 1 | -0.075505 | 3.10382 | 0 | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
1 rows × 81 columns
| | AcuteDiagnosis | AgeLastBirthday | unique_ICD9CodesDesc_p12m | MedianBMI_p12m | MedianBMI_p12m_missing | MedianMAP_p12m | MedianMAP_p12m_missing | unique_PhysicianSpecialtyGroups_p12m | NumberOfAbnormalLabResults_p12m | SmokerStatus | ... | StateCode_WI | StateCode_WV | StateCode_WY | BloodType_A- | BloodType_AB+ | BloodType_AB- | BloodType_B+ | BloodType_B- | BloodType_O+ | BloodType_O- |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | -1.233832 | 0.239989 | -0.745067 | 0 | 1.036563 | 0 | -0.075505 | -0.761468 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1 rows × 81 columns
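Fitting the scaler on the training set only means that the validation and test sets are scaled with the training mean and standard deviation. The sketch below reproduces that arithmetic by hand with NumPy on hypothetical values for a single feature. It is an illustration of the idea rather than the notebook's scaler, and it uses the population standard deviation (ddof = 0), which is what StandardScaler uses.

```python
import numpy as np

# Hypothetical values of one feature in the training and validation sets.
train = np.array([20.0, 30.0, 40.0, 50.0])
val = np.array([35.0, 60.0])

# Statistics come from the training split only, so no information from the
# validation data leaks into the transformation.
mu = train.mean()
sigma = train.std()  # population standard deviation (ddof=0)

train_scaled = (train - mu) / sigma
val_scaled = (val - mu) / sigma

print(train_scaled.mean(), train_scaled.std())
print(val_scaled)
```

The scaled training values have mean 0 and standard deviation 1 by construction. The validation values generally do not, which is expected.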
Separation of Features and Response Variable¶
# Q2e Step 3
X_training = training_df.drop('AcuteDiagnosis', axis = 1)
Y_training = training_df['AcuteDiagnosis']
X_val = val_df.drop('AcuteDiagnosis', axis = 1)
Y_val = val_df['AcuteDiagnosis']
X_test = test_df.drop('AcuteDiagnosis', axis = 1)
Y_test = test_df['AcuteDiagnosis']
# Verify structure of features and response variables
display(X_training.shape)
display(Y_training.shape)
display(X_val.shape)
display(Y_val.shape)
display(X_test.shape)
display(Y_test.shape)
(45374, 80)
(45374,)
(8706, 80)
(8706,)
(9749, 80)
(9749,)
Savepoint¶
X_training.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_training.pkl')
Y_training.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/Y_training.pkl')
X_val.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_val.pkl')
Y_val.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/Y_val.pkl')
X_test.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_test.pkl')
Y_test.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/Y_test.pkl')
Loadpoint¶
X_training = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_training.pkl')
Y_training = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/Y_training.pkl')
X_val = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_val.pkl')
Y_val = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/Y_val.pkl')
X_test = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_test.pkl')
Y_test = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/Y_test.pkl')
Q2e Step 4 - Helper Functions¶
Plot Model Loss¶
# Q2e Step 4
def plot_training_val_loss(history, title = "Training vs. Validation Loss"):
    """
    Plots training and validation loss over epochs.
    Parameters:
    history : Keras History object (as returned by model.fit(...))
    title : Title for the plot
    """
    plt.figure(figsize = (8, 5))
    plt.plot(history.history['loss'], label = 'Training')
    plt.plot(history.history['val_loss'], label = 'Validation')
    plt.title(title)
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc = 'upper right')
    plt.grid(True)
    plt.tight_layout()
    plt.show()
Success Metrics¶
# Q2e Step 4
# Business weights (i.e. model-agnostic)
# Define W_FN and W_FP before use in function definitions
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
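def cost_sensitive_error(y_true, y_pred, w_fn=W_FN, w_fp=W_FP):
    """
    Weighted misclassification cost per observation: (w_fn * FN + w_fp * FP) / N.
    """
    # Flatten so that (n, 1) model outputs compare element-wise with (n,) labels
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred = np.asarray(y_pred).astype(int).ravel()
    FN = np.sum((y_true == 1) & (y_pred == 0))
    FP = np.sum((y_true == 0) & (y_pred == 1))
    N = len(y_true)
    # Guard against an empty input, as the other helpers do
    return (w_fn * FN + w_fp * FP) / max(N, 1)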
def confusion_matrix_df(y_true, y_prob, threshold):
    """
    Build a confusion matrix DataFrame at a given probability threshold.
    Parameters
    ----------
    y_true : array-like of shape (n,)
        Ground truth labels in {0,1}.
    y_prob : array-like of shape (n,)
        Predicted probabilities in [0,1].
    threshold : float
        Classify as 1 when y_prob >= threshold, else 0.
    Returns
    -------
    cm : pd.DataFrame
        Confusion matrix with rows = Actual {0,1}, cols = Pred {0,1}.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_prob = np.asarray(y_prob).astype(float).ravel()
    y_pred = (y_prob >= threshold).astype(int)
    TN = int(np.sum((y_true == 0) & (y_pred == 0)))
    FP = int(np.sum((y_true == 0) & (y_pred == 1)))
    FN = int(np.sum((y_true == 1) & (y_pred == 0)))
    TP = int(np.sum((y_true == 1) & (y_pred == 1)))
    cm = pd.DataFrame(
        [[TN, FP],
         [FN, TP]],
        index=pd.Index(["Actual 0", "Actual 1"], name=""),
        columns=pd.Index(["Predicted 0", "Predicted 1"], name="")
    )
    return cm
def confusion_summary(y_true, y_prob, threshold):
    """
    Compute key classification metrics at a fixed threshold.
    Returns a one-row DataFrame with counts + rates.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_prob = np.asarray(y_prob).astype(float).ravel()
    y_pred = (y_prob >= threshold).astype(int)
    TN = int(np.sum((y_true == 0) & (y_pred == 0)))
    FP = int(np.sum((y_true == 0) & (y_pred == 1)))
    FN = int(np.sum((y_true == 1) & (y_pred == 0)))
    TP = int(np.sum((y_true == 1) & (y_pred == 1)))
    N = len(y_true)
    # Safeguard divisions
    def safe_div(a, b):
        return float(a) / float(b) if b else 0.0
    recall = safe_div(TP, TP + FN)
    precision = safe_div(TP, TP + FP)
    return pd.DataFrame([{
        "N": N,
        "Threshold": float(threshold),
        "TN": TN, "FP": FP, "FN": FN, "TP": TP,
        "Recall": recall,
        "Precision": precision
    }])
def plot_confusion_matrix(cm_df, title="Confusion matrix"):
    """
    Simple heatmap-style plot from the DataFrame returned by confusion_matrix_df.
    Single-axes matplotlib plot.
    """
    fig, ax = plt.subplots(figsize=(4.5, 4))
    im = ax.imshow(cm_df.values, aspect='auto')
    ax.set_xticks(range(cm_df.shape[1]))
    ax.set_xticklabels(cm_df.columns)
    ax.set_yticks(range(cm_df.shape[0]))
    ax.set_yticklabels(cm_df.index)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title(title)
    # annotate counts
    for i in range(cm_df.shape[0]):
        for j in range(cm_df.shape[1]):
            ax.text(j, i, f"{int(cm_df.iloc[i, j])}", ha='center', va='center')
    plt.tight_layout()
    return ax
def choose_threshold(
    y_true, y_prob,
    w_fn=W_FN, w_fp=W_FP,
    plot=False, return_curve=False, ax=None,
    overlay_precision_recall=True,
    plot_roc=False, roc_ax=None, return_roc_curve=False, model_name=None
):
    """
    Pick a single threshold that minimises cost-sensitive misclassification on validation.
    Optional:
    - plot cost vs threshold (and overlay precision/recall on a twin y-axis)
    - plot the ROC curve (separate figure)
    - return the full cost-tuning curve DataFrame
    - return the ROC curve DataFrame (+ AUC)
    Returns
    -------
    thr_star : float
        Chosen threshold (policy cut).
    cost_star : float
        Minimum cost-sensitive error at thr_star.
    curve_df : pandas.DataFrame (optional if return_curve=True)
        Columns: ['k','threshold','TP','FP','FN','cost','precision','recall'].
    roc_df, auc : (optional if return_roc_curve=True)
        roc_df columns: ['fpr','tpr','threshold'] and AUC float.
    """
    # Flatten so that (n, 1) model outputs are handled like (n,) arrays
    y = np.asarray(y_true).astype(int).ravel()
    p = np.asarray(y_prob).astype(float).ravel()
    n = len(y)
    if n == 0:
        thr_star, cost_star = 1.0, 0.0
        outs = (thr_star, cost_star)
        if return_curve:
            outs = outs + (pd.DataFrame(),)
        if return_roc_curve:
            outs = outs + (pd.DataFrame(columns=["fpr","tpr","threshold"]), np.nan)
        return outs
    # -------- Cost curve mechanics --------
    # Sort by score descending
    order = np.argsort(-p)
    y_sorted = y[order]
    p_sorted = p[order]
    # Cumulative TP when taking top-k as positives
    tp_cum = np.cumsum(y_sorted == 1)
    P = tp_cum[-1]  # total positives
    ks = np.arange(0, n+1)  # number predicted positive
    TP = np.concatenate(([0], tp_cum))
    FP = ks - TP
    FN = P - TP
    cost = (w_fn * FN + w_fp * FP) / max(n, 1)
    # Map each k to a threshold that yields exactly k predicted positives
    # (exact only when there are no tied scores)
    thresholds = np.empty(n+1, dtype=float)
    # Step towards +/- infinity so the end points also work when a score equals 1.0 or 0.0
    thresholds[0] = np.nextafter(p_sorted[0], np.inf)  # predict none positive
    thresholds[n] = np.nextafter(p_sorted[-1], -np.inf)  # predict all positive
    if n > 1:
        mids = 0.5 * (p_sorted[:-1] + p_sorted[1:])
        thresholds[1:n] = mids
    # Precision/Recall for completeness (threshold-dependent)
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = np.where(ks == 0, 0.0, TP / ks)
        recall = np.where(P == 0, 0.0, TP / P)
    # Choose k★ that minimises cost; convert to thr_star
    k_star = int(np.argmin(cost))
    thr_star = float(thresholds[k_star])
    cost_star = float(cost[k_star])
    # Optional cost/precision/recall plot (single axes with twin y-axis)
    if plot:
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.plot(thresholds, cost, label="Cost-sensitive error")
        ax.axvline(thr_star, linestyle="--", label=f"Optimal Threshold = {thr_star:.4f}")
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Cost-Sensitive Error")
        ax.set_title(f"Threshold Tuning on Validation Set ({model_name})")
        if overlay_precision_recall:
            ax2 = ax.twinx()
            ax2.plot(thresholds, precision, linestyle="--", label="Precision")
            ax2.plot(thresholds, recall, linestyle="-.", label="Recall")
            ax2.set_ylabel("Precision / Recall")
            # Merge legends
            l1, lab1 = ax.get_legend_handles_labels()
            l2, lab2 = ax2.get_legend_handles_labels()
            ax.legend(l1 + l2, lab1 + lab2, loc="best")
        else:
            ax.legend(loc="best")
        plt.tight_layout()
    # -------- ROC curve mechanics --------
    roc_df = None
    auc = np.nan
    if plot_roc or return_roc_curve:
        try:
            fpr, tpr, thr_roc = roc_curve(y, p)
            auc = roc_auc_score(y, p)
            if plot_roc:
                if roc_ax is None:
                    fig2, roc_ax = plt.subplots(figsize=(5.5, 5.5))
                roc_ax.plot(fpr, tpr, label="ROC")
                roc_ax.plot([0, 1], [0, 1], linestyle="--", label="Random Chance")
                roc_ax.set_xlabel("False Positive Rate")
                roc_ax.set_ylabel("True Positive Rate")
                roc_ax.set_title(f"ROC Curve (AUC = {auc:.3f}) ({model_name})")
                roc_ax.legend(loc="lower right")
                plt.tight_layout()
            if return_roc_curve:
                roc_df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "threshold": thr_roc})
        except ValueError:
            # Only one class present; keep auc=nan and optional empty roc_df
            if return_roc_curve:
                roc_df = pd.DataFrame(columns=["fpr","tpr","threshold"])
    # Build return tuple
    outs = (thr_star, cost_star)
    if return_curve:
        curve_df = pd.DataFrame({
            "k": ks,
            "threshold": thresholds,
            "TP": TP.astype(int),
            "FP": FP.astype(int),
            "FN": FN.astype(int),
            "cost": cost,
            "precision": precision,
            "recall": recall
        })
        outs = outs + (curve_df,)
    if return_roc_curve:
        outs = outs + (roc_df, auc)
    return outs
def eval_split(y_true, y_prob, threshold):
    """
    Compute split-level metrics at the fixed policy threshold, plus threshold-free ROC–AUC.
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    y_pred = (y_prob >= threshold).astype(int)
    # ROC–AUC requires both classes present; return NaN otherwise.
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc = np.nan
    metrics = {
        "N": len(y_true),
        "Prevalence": float(np.mean(y_true)) if len(y_true) else np.nan,
        "Threshold": float(threshold),
        "Cost_sens_error": cost_sensitive_error(y_true, y_pred),
        "Sensitivity": recall_score(y_true, y_pred, zero_division=0),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "ROC_AUC": auc,
    }
    return metrics
def plot_metrics_test(
    y_true, y_prob,
    w_fn=W_FN, w_fp=W_FP,
    threshold=None,  # fixed policy cut (float); if None, the cost-minimising cut is chosen
    plot=False, return_curve=False, ax=None,
    overlay_precision_recall=True,
    plot_roc=False, roc_ax=None, return_roc_curve=False,
    title_cost=None, title_roc=None, model_name=None
):
    """
    If `threshold` is provided:
    • Compute and RETURN the cost at that fixed threshold (no optimisation).
    • Plot cost/precision/recall vs threshold and draw a vertical dashed line at your threshold.
    If `threshold` is None:
    • Behaves like choose_threshold: chooses the threshold that minimises cost on y_true/y_prob.
    Returns
    -------
    thr_out : float
        The fixed threshold you passed (or the cost-minimising threshold if none provided).
    cost_out : float
        Misclassification cost at thr_out.
    curve_df : pd.DataFrame (optional if return_curve=True)
        Columns: ['k','threshold','TP','FP','FN','cost','precision','recall'] (threshold grid for the plot).
    roc_df, auc : (optional if return_roc_curve=True)
        ROC curve data and AUC for the same split.
    """
    y = np.asarray(y_true).astype(int).ravel()
    p = np.asarray(y_prob).astype(float).ravel()
    n = len(y)
    if n == 0:
        thr_out, cost_out = (1.0, 0.0)
        outs = (thr_out, cost_out)
        if return_curve:
            outs += (pd.DataFrame(),)
        if return_roc_curve:
            outs += (pd.DataFrame(columns=["fpr","tpr","threshold"]), np.nan)
        return outs
    # ---------- Build the threshold grid for plotting ----------
    order = np.argsort(-p)  # desc
    y_sorted = y[order]
    p_sorted = p[order]
    tp_cum = np.cumsum(y_sorted == 1)
    P = int(tp_cum[-1])
    ks = np.arange(0, n+1)  # number predicted positive
    TP = np.concatenate(([0], tp_cum))
    FP = ks - TP
    FN = P - TP
    cost = (w_fn * FN + w_fp * FP) / max(n, 1)
    thresholds_grid = np.empty(n+1, dtype=float)
    # Step towards +/- infinity so the end points also work when a score equals 1.0 or 0.0
    thresholds_grid[0] = np.nextafter(p_sorted[0], np.inf)  # predict none positive
    thresholds_grid[n] = np.nextafter(p_sorted[-1], -np.inf)  # predict all positive
    if n > 1:
        mids = 0.5 * (p_sorted[:-1] + p_sorted[1:])
        thresholds_grid[1:n] = mids
    # Precision / Recall along the grid (for overlays)
    with np.errstate(divide='ignore', invalid='ignore'):
        precision_grid = np.where(ks == 0, 0.0, TP / ks)
        recall_grid = np.where(P == 0, 0.0, TP / P)
    # ---------- Decide which threshold to mark & what cost to return ----------
    if threshold is None:
        # Choose cost-minimising threshold
        k_star = int(np.argmin(cost))
        thr_out = float(thresholds_grid[k_star])
        cost_out = float(cost[k_star])
        vline_label = f"Chosen threshold (min cost) = {thr_out:.4f}"
    else:
        # Use caller-specified fixed threshold
        thr_out = float(threshold)
        y_pred_fixed = (p >= thr_out).astype(int)
        FN_fixed = int(np.sum((y == 1) & (y_pred_fixed == 0)))
        FP_fixed = int(np.sum((y == 0) & (y_pred_fixed == 1)))
        cost_out = (w_fn * FN_fixed + w_fp * FP_fixed) / max(n, 1)
        vline_label = f"Chosen threshold = {thr_out:.2f}"
    # ---------- Plot cost with precision/recall overlays (single axes + twin y) ----------
    if plot:
        # Ensure the x-limits include the fixed threshold even if outside grid
        x_min = float(min(thresholds_grid.min(), thr_out))
        x_max = float(max(thresholds_grid.max(), thr_out))
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.plot(thresholds_grid, cost, label="Cost-sensitive error")
        ax.axvline(thr_out, linestyle="--", label=vline_label)
        ax.set_xlim(x_min, x_max)
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Cost-sensitive error")
        ax.set_title(title_cost or "Validation/Test threshold view (cost + precision/recall)")
        if overlay_precision_recall:
            ax2 = ax.twinx()
            ax2.plot(thresholds_grid, precision_grid, linestyle="--", label="Precision")
            ax2.plot(thresholds_grid, recall_grid, linestyle="-.", label="Recall")
            ax2.set_ylabel("Precision / Recall")
            l1, lab1 = ax.get_legend_handles_labels()
            l2, lab2 = ax2.get_legend_handles_labels()
            ax.legend(l1 + l2, lab1 + lab2, loc="best")
        else:
            ax.legend(loc="best")
        plt.tight_layout()
    # ---------- Optional ROC curve ----------
    roc_df = None
    auc = np.nan
    if plot_roc or return_roc_curve:
        try:
            fpr, tpr, thr_roc = roc_curve(y, p)
            auc = roc_auc_score(y, p)
            if plot_roc:
                if roc_ax is None:
                    fig2, roc_ax = plt.subplots(figsize=(5.5, 5.5))
                roc_ax.plot(fpr, tpr, label="ROC")
                roc_ax.plot([0, 1], [0, 1], linestyle="--", label="Random Chance")
                roc_ax.set_xlabel("False Positive Rate")
                roc_ax.set_ylabel("True Positive Rate")
                # Honour title_roc when supplied; otherwise fall back to the default test-data title
                roc_ax.set_title(title_roc or f"ROC Curve (AUC = {auc:.3f}) on Test Data ({model_name})")
                roc_ax.legend(loc="lower right")
                plt.tight_layout()
            if return_roc_curve:
                roc_df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "threshold": thr_roc})
        except ValueError:
            if return_roc_curve:
                roc_df = pd.DataFrame(columns=["fpr","tpr","threshold"])
    # ---------- Build return tuple ----------
    outs = (thr_out, float(cost_out))
    if return_curve:
        curve_df = pd.DataFrame({
            "k": ks,
            "threshold": thresholds_grid,
            "TP": TP.astype(int),
            "FP": FP.astype(int),
            "FN": FN.astype(int),
            "cost": cost,
            "precision": precision_grid,
            "recall": recall_grid
        })
        outs += (curve_df,)
    if return_roc_curve:
        outs += (roc_df, auc)
    return outs
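The threshold search in choose_threshold can be hard to follow inside the full function, so the following is a minimal pure-Python sketch of the same idea: rank the scores, flag the top k as positive for every k, and keep the k with the lowest weighted cost. The function name min_cost_threshold and the toy data are hypothetical and are not part of the notebook; the sketch assumes all scores are distinct.

```python
import math

def min_cost_threshold(y_true, y_prob, w_fn=2.0, w_fp=1.0):
    """Return (threshold, cost) minimising (w_fn * FN + w_fp * FP) / N.

    Assumes a non-empty input with distinct scores, so that flagging the
    top-k scores corresponds to a single threshold interval.
    """
    pairs = sorted(zip(y_prob, y_true), reverse=True)  # highest score first
    n = len(pairs)
    if n == 0:
        raise ValueError("need at least one observation")
    total_pos = sum(y for _, y in pairs)
    # k = 0: nobody is flagged, so every positive is a false negative
    best_thr, best_cost = math.inf, w_fn * total_pos / n
    tp = 0
    for k in range(1, n + 1):
        tp += pairs[k - 1][1]          # true positives among the top k
        fp = k - tp                    # the rest of the top k are false positives
        fn = total_pos - tp            # positives that were not flagged
        cost = (w_fn * fn + w_fp * fp) / n
        if cost < best_cost:           # strict '<' keeps the first minimum
            # Threshold halfway between the k-th and (k+1)-th score
            thr = -math.inf if k == n else 0.5 * (pairs[k - 1][0] + pairs[k][0])
            best_thr, best_cost = thr, cost
    return best_thr, best_cost

# Hypothetical toy data: five patients, two of whom are truly acute
y_toy = [1, 0, 1, 0, 0]
p_toy = [0.9, 0.8, 0.6, 0.3, 0.1]
thr_toy, cost_toy = min_cost_threshold(y_toy, p_toy)
print(thr_toy, cost_toy)
```

With a false negative costing twice a false positive, the sketch flags the top three scores: it accepts one false positive (the 0.8 score) in order to catch the second acute case at 0.6.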
Plot Decile Chart¶
def plot_decile_chart(model, X, y_true, model_name=None):
    """
    Plots a decile chart showing observed acute probability by predicted risk decile.
    Parameters:
    model : Trained model whose .predict() returns probabilities
    X : Feature set
    y_true : True binary labels
    model_name : Model label appended to the plot title
    """
    # Step 1: Predict probabilities
    y_prob = model.predict(X).ravel()
    # Step 2: Create DataFrame
    df_pred = pd.DataFrame({"y_prob": y_prob, "y_true": y_true})
    # Step 3: Assign deciles (0 = highest risk)
    df_pred["decile"] = pd.qcut(df_pred["y_prob"], 10, labels=False, duplicates="drop")
    # Reverse the labels; using max() keeps 0 as the highest-risk group even if
    # duplicates="drop" leaves fewer than 10 bins
    df_pred["decile"] = df_pred["decile"].max() - df_pred["decile"]
    # Step 4: Compute observed acute probabilities per decile
    decile_summary = df_pred.groupby("decile")["y_true"].mean().reset_index()
    # Step 5: Plot
    plt.figure(figsize=(8, 5))
    plt.bar(decile_summary["decile"], decile_summary["y_true"], color='skyblue')
    plt.xlabel("Decile (0 = Highest Risk)")
    plt.ylabel("Observed Acute Probability")
    plt.title(f"Decile Plot of Acute Probabilities by Predicted Risk ({model_name})")
    plt.xticks(decile_summary["decile"])
    plt.grid(axis='y')
    plt.tight_layout()
    plt.show()
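To make the quantity behind the decile chart concrete, here is a small pure-Python sketch that ranks patients by predicted probability, splits them into ten equal-sized groups and reports the observed event rate in each. It groups by rank rather than calling pd.qcut, so it matches the notebook's binning only when the predicted probabilities are distinct. The function name observed_rate_by_decile and the toy data are hypothetical.

```python
def observed_rate_by_decile(y_true, y_prob, n_groups=10):
    """Observed event rate per risk group; group 0 holds the highest scores."""
    ranked = sorted(zip(y_prob, y_true), reverse=True)  # highest risk first
    n = len(ranked)
    rates = []
    for g in range(n_groups):
        lo, hi = g * n // n_groups, (g + 1) * n // n_groups
        group = [y for _, y in ranked[lo:hi]]
        rates.append(sum(group) / len(group) if group else float("nan"))
    return rates

# Hypothetical toy data: 20 patients with scores 1/20, 2/20, ..., 20/20.
# The patients ranked 20, 19, 18, 16, 13 and 7 (by score) are truly acute.
ranks = list(range(1, 21))
p_toy = [r / 20 for r in ranks]
y_toy = [1 if r in {20, 19, 18, 16, 13, 7} else 0 for r in ranks]

rates = observed_rate_by_decile(y_toy, p_toy)
print(rates)
```

A model that ranks well produces rates that fall steadily from group 0 to group 9; the bump in group 6 of the toy data is the kind of irregularity the chart is meant to expose.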
Plot Actual vs Expected¶
# Function to generate actual vs expected plot
def plot_actual_vs_expected(y_true, y_prob, n_bins=10, model_name=None):
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    bins = np.linspace(0, 1, n_bins+1)
    idx = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins-1)
    df = pd.DataFrame({"bin": idx, "y": y_true, "p": y_prob})
    g = df.groupby("bin").agg(
        p_mean=("p","mean"),
        y_rate=("y","mean"),
        n=("y","size")
    ).dropna()
    plt.plot([0,1],[0,1], linestyle="--", label="Perfect Calibration")
    plt.plot(g["p_mean"], g["y_rate"], marker="o", label=f"{model_name}")
    # Annotate each point with the number of observations in its bin
    for xm, ym, nn in zip(g["p_mean"], g["y_rate"], g["n"]):
        plt.annotate(str(int(nn)), (xm, ym), textcoords="offset points", xytext=(0,6), ha="center", fontsize=8)
    plt.xlabel("Mean Predicted Probability (per bin)")
    plt.ylabel("Observed Acute Probability (per bin)")
    plt.title(f"Plot of Actual vs Expected ({model_name})")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.show()
Set Deterministic Nature¶
# Q2e Step 4
# Ensure reproducibility by setting seeds
def keras_deterministic(seed_value = 0):
    # 1. Set 'PYTHONHASHSEED' environment variable at a fixed value
    #    (Python reads it at interpreter start-up, so this only affects processes launched afterwards)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    # 2. Set 'python' built-in pseudo-random generator at a fixed value
    random.seed(seed_value)
    # 3. Set 'numpy' pseudo-random generator at a fixed value
    np.random.seed(seed_value)
    # 4. Set the 'tensorflow' pseudo-random generator at a fixed value
    tf.random.set_seed(seed_value)
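The reason the seeds are reset before every model build can be illustrated with the standard-library generator alone. The sketch below is a simplified stand-in for what keras_deterministic does across Python, NumPy and TensorFlow; the helper name draws is hypothetical.

```python
import random

def draws(seed_value, n=3):
    """Reseed the built-in generator, then return the next n random numbers."""
    random.seed(seed_value)
    return [random.random() for _ in range(n)]

first_run = draws(0)
second_run = draws(0)   # same seed, so the same sequence is replayed
other_seed = draws(1)   # a different seed gives a different sequence

print(first_run == second_run, first_run == other_seed)
```

Without the reseed, the second model in the notebook would start from wherever the first model left the generators, so two otherwise identical architectures could receive different initial weights.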
Q2e Step 5 - Starting Model (Iteration 1)¶
# Q2e Step 5: Model construction
# Set seed for reproducible results
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Define the initial neural network architecture
nn1 = keras.Sequential([
    layers.Input(shape = (X_training.shape[1],)),
    # First hidden layer
    layers.Dense(64, activation = 'relu'),
    # Regularisation with dropout
    layers.Dropout(0.3),
    # Second hidden layer
    layers.Dense(32, activation = 'tanh'),
    # Regularisation with dropout
    layers.Dropout(0.3),
    # Output layer for binary classification
    layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn1.compile(
    optimizer = keras.optimizers.Adam(learning_rate = 0.001),
    loss = 'binary_crossentropy',
    metrics = [
        tf.keras.metrics.Precision(name = 'precision'),
        tf.keras.metrics.Recall(name = 'recall'),
        tf.keras.metrics.AUC(name = 'auc')
    ]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_1 = nn1.fit(
    X_training, Y_training,
    validation_data = (X_val, Y_val),
    epochs = 50,
    batch_size = 32,
    callbacks = [early_stopping],
    verbose = 1
)
# Check model parameters
display(nn1.summary())
Epoch 1/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.5671 - loss: 0.5766 - precision: 0.3339 - recall: 0.0335 - val_auc: 0.5969 - val_loss: 0.5674 - val_precision: 1.0000 - val_recall: 8.7222e-04
Epoch 2/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 994us/step - auc: 0.6258 - loss: 0.5515 - precision: 0.4231 - recall: 0.0112 - val_auc: 0.5818 - val_loss: 0.5740 - val_precision: 0.2692 - val_recall: 0.0031
Epoch 3/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - auc: 0.6561 - loss: 0.5398 - precision: 0.5907 - recall: 0.0370 - val_auc: 0.5770 - val_loss: 0.5803 - val_precision: 0.3373 - val_recall: 0.0122
Epoch 4/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 955us/step - auc: 0.6793 - loss: 0.5301 - precision: 0.5953 - recall: 0.0743 - val_auc: 0.5686 - val_loss: 0.5882 - val_precision: 0.3510 - val_recall: 0.0231
Epoch 5/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 932us/step - auc: 0.6953 - loss: 0.5223 - precision: 0.5891 - recall: 0.1102 - val_auc: 0.5670 - val_loss: 0.5940 - val_precision: 0.3483 - val_recall: 0.0406
Epoch 6/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 931us/step - auc: 0.7109 - loss: 0.5146 - precision: 0.5961 - recall: 0.1372 - val_auc: 0.5626 - val_loss: 0.6016 - val_precision: 0.3675 - val_recall: 0.0532
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 64)             │         5,184 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 64)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 32)             │         2,080 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 32)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │            33 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 21,893 (85.52 KB)
Trainable params: 7,297 (28.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 14,596 (57.02 KB)
None
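The parameter counts in the summary can be checked by hand: a Dense layer has one weight per input-unit pair plus one bias per unit, and the feature matrix has 80 columns. The sketch below reproduces the figures. The optimizer line rests on an assumption inferred from the numbers rather than stated in the notebook, namely that Adam keeps two values per trainable parameter plus two scalar variables; dense_params is a hypothetical helper.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units

# 80 input features, then the three Dense layers of the network
layer_sizes = [80, 64, 32, 1]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
trainable = sum(per_layer)

# Assumption: Adam stores two moment estimates per trainable parameter,
# plus two scalar variables; this matches the reported 14,596
optimizer = 2 * trainable + 2
total = trainable + optimizer

print(per_layer, trainable, optimizer, total)
```

Dropout layers contribute no parameters, which is why they show 0 in the summary.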
# Q2e Step 5: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_1, title = 'Training vs. Validation Loss (NN 1)')
p_va = nn1.predict(X_val, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.4)
sum_va = confusion_summary(y_va, p_va, 0.4)
print("\nConfusion Matrix Summary (0.4 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.4 threshold) (NN 1)')
plt.show()
Confusion Matrix Summary (0.4 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.4 | 6180 | 233 | 2123 | 170 | 0.074139 | 0.421836 |
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
    y_true = y_va, y_prob = p_va,
    w_fn = W_FN, w_fp = W_FP,
    plot = True,
    return_curve = True,
    plot_roc = True,
    return_roc_curve = True,
    model_name = 'NN 1'
)
threshold = 0.4
y_val_probs = nn1.predict(X_val, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.5145 | Optimal threshold: 0.4020 | Optimal Misclassification Cost: 0.5134 | Validation AUC: 0.597
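The misclassification cost, recall and precision reported for NN 1 at the 0.4 threshold can be reproduced by hand from the validation confusion counts (TN = 6180, FP = 233, FN = 2123, TP = 170). The following is a worked check in plain Python, using the same business weights as the notebook.

```python
# Business weights used throughout the notebook
W_FN, W_FP = 2.0, 1.0

# Validation confusion counts for NN 1 at the 0.4 threshold
TN, FP, FN, TP = 6180, 233, 2123, 170
N = TN + FP + FN + TP

# Cost-sensitive error: each missed acute case counts twice as much as a false alarm
cost = (W_FN * FN + W_FP * FP) / N
recall = TP / (TP + FN)
precision = TP / (TP + FP)

print(f"N = {N}, cost = {cost:.4f}, recall = {recall:.6f}, precision = {precision:.6f}")
```

Almost all of the cost comes from the false-negative term (2 × 2123 = 4246 of the 4479 weighted errors), which is consistent with the very low recall at this threshold.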
# Plot decile chart
plot_decile_chart(nn1, X_val, Y_val, model_name = 'NN 1')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
# Create a DataFrame to store key metrics for each model
model_metrics_df = pd.DataFrame(columns = [
    'Model Name',
    'Threshold',
    'Validation Loss',
    'Misclassification Cost',
    'Recall',
    'Precision',
    'ROC_AUC',
    'Optimal Threshold',
    'Optimal Misclassification Cost'
])
# Create a new row with the metrics for NN 1
nn1_metrics = {
    'Model Name': 'NN 1',
    'Threshold': sum_va['Threshold'].iloc[0],
    'Validation Loss': min(nn_hist_1.history['val_loss']),
    'Misclassification Cost': cost_sens_error,
    'Recall': sum_va['Recall'].iloc[0],
    'Precision': sum_va['Precision'].iloc[0],
    'ROC_AUC': auc,
    'Optimal Threshold': thr_star,
    'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn1_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
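The Validation Loss of 0.567409 in this table is the minimum over the epochs, which the NN 1 training log shows was reached in epoch 1; training then stopped after epoch 6. The sketch below replays a simplified version of the early-stopping rule (stop after `patience` consecutive epochs without a new best loss) on the logged values. It is an illustration only and ignores options of the Keras callback such as min_delta; early_stopping_replay is a hypothetical helper.

```python
def early_stopping_replay(val_losses, patience):
    """Return (best_epoch, stop_epoch), both 1-based, for a simple
    'stop after `patience` epochs without improvement' rule."""
    best_loss, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0   # new best: reset counter
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch                    # patience exhausted
    return best_epoch, len(val_losses)                      # ran out of epochs

# val_loss per epoch, copied from the NN 1 training log
nn1_val_loss = [0.5674, 0.5740, 0.5803, 0.5882, 0.5940, 0.6016]
best_epoch, stop_epoch = early_stopping_replay(nn1_val_loss, patience=5)
print(best_epoch, stop_epoch)
```

Because restore_best_weights = True, the model used for the later validation metrics carries the epoch 1 weights, not those from the final epoch.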
Q2e Step 6 - Adjust Class Imbalance (Iteration 2)¶
# Q2e Step 6: Model construction
# Get number of 0s and 1s in Y_training
print(f"Number of 0s in Y_training: {np.sum(Y_training == 0)}")
print(f"Number of 1s in Y_training: {np.sum(Y_training == 1)}")
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Rebalance class weights
class_weights = class_weight.compute_class_weight(
    class_weight = 'balanced',
    classes = np.unique(Y_training),
    y = Y_training
)
class_weight_dict = dict(enumerate(class_weights))
# Refine the neural network architecture
nn2 = keras.Sequential([
    layers.Input(shape = (X_training.shape[1],)),
    layers.Dense(64, activation = 'relu'),
    layers.Dropout(0.3),
    layers.Dense(32, activation = 'tanh'),
    layers.Dropout(0.3),
    layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn2.compile(
    optimizer = keras.optimizers.Adam(learning_rate = 0.001),
    loss = 'binary_crossentropy',
    metrics = [
        tf.keras.metrics.Precision(name = 'precision'),
        tf.keras.metrics.Recall(name = 'recall'),
        tf.keras.metrics.AUC(name = 'auc')
    ]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_2 = nn2.fit(
    X_training, Y_training,
    validation_data = (X_val, Y_val),
    epochs = 50,
    batch_size = 32,
    callbacks = [early_stopping],
    # Insert rebalanced class weights
    class_weight = class_weight_dict,
    verbose = 1
)
# Check model parameters
display(nn2.summary())
Number of 0s in Y_training: 33886
Number of 1s in Y_training: 11488
Epoch 1/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.5684 - loss: 0.6957 - precision: 0.2927 - recall: 0.5581 - val_auc: 0.5987 - val_loss: 0.6740 - val_precision: 0.3282 - val_recall: 0.6066
Epoch 2/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 990us/step - auc: 0.6297 - loss: 0.6702 - precision: 0.3306 - recall: 0.6352 - val_auc: 0.5801 - val_loss: 0.6823 - val_precision: 0.3168 - val_recall: 0.5831
Epoch 3/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 986us/step - auc: 0.6589 - loss: 0.6555 - precision: 0.3490 - recall: 0.6575 - val_auc: 0.5743 - val_loss: 0.6808 - val_precision: 0.3149 - val_recall: 0.5556
Epoch 4/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 983us/step - auc: 0.6837 - loss: 0.6419 - precision: 0.3648 - recall: 0.6785 - val_auc: 0.5668 - val_loss: 0.6879 - val_precision: 0.3037 - val_recall: 0.5229
Epoch 5/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 952us/step - auc: 0.7001 - loss: 0.6318 - precision: 0.3729 - recall: 0.6818 - val_auc: 0.5641 - val_loss: 0.6857 - val_precision: 0.3083 - val_recall: 0.5041
Epoch 6/50
1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 974us/step - auc: 0.7176 - loss: 0.6199 - precision: 0.3911 - recall: 0.6952 - val_auc: 0.5528 - val_loss: 0.6934 - val_precision: 0.3043 - val_recall: 0.4727
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 64)             │         5,184 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 64)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 32)             │         2,080 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 32)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │            33 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 21,893 (85.52 KB)
Trainable params: 7,297 (28.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 14,596 (57.02 KB)
None
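The 'balanced' option of compute_class_weight uses the heuristic n_samples / (n_classes * class_count). Applied to the class counts printed at the top of this output (33,886 zeros and 11,488 ones), the weights passed to the fit can be worked out directly in plain Python:

```python
# Class counts printed for Y_training
counts = {0: 33886, 1: 11488}

n_samples = sum(counts.values())
n_classes = len(counts)

# 'balanced' heuristic: n_samples / (n_classes * count of that class)
weights = {c: n_samples / (n_classes * k) for c, k in counts.items()}

for c, w in weights.items():
    print(f"class {c}: weight = {w:.4f}, weighted count = {w * counts[c]:.1f}")
```

Each class ends up with the same weighted count (half the sample size), so an error on an acute case contributes roughly three times as much to the training loss as an error on a non-acute case.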
# Q2e Step 6: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_2, title = 'Training vs. Validation Loss (NN 2)')
p_va = nn2.predict(X_val, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.4)
sum_va = confusion_summary(y_va, p_va, 0.4)
print("\nConfusion Matrix Summary (0.4 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.4 threshold) (NN 2)')
plt.show()
Confusion Matrix Summary (0.4 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.4 | 1726 | 4687 | 377 | 1916 | 0.835587 | 0.290171 |
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
    y_true = y_va, y_prob = p_va,
    w_fn = W_FN, w_fp = W_FP,
    plot = True,
    return_curve = True,
    plot_roc = True,
    return_roc_curve = True,
    model_name = 'NN 2'
)
threshold = 0.4
y_val_probs = nn2.predict(X_val, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.6250 | Optimal threshold: 0.6623 | Optimal Misclassification Cost: 0.5090 | Validation AUC: 0.599
# Plot decile chart
plot_decile_chart(nn2, X_val, Y_val, model_name = 'NN 2')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
# Create a new row with the metrics for NN 2
nn2_metrics = {
    'Model Name': 'NN 2',
    'Threshold': sum_va['Threshold'].iloc[0],
    'Validation Loss': min(nn_hist_2.history['val_loss']),
    'Misclassification Cost': cost_sens_error,
    'Recall': sum_va['Recall'].iloc[0],
    'Precision': sum_va['Precision'].iloc[0],
    'ROC_AUC': auc,
    'Optimal Threshold': thr_star,
    'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn2_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
| 1 | NN 2 | 0.4 | 0.673952 | 0.624971 | 0.835587 | 0.290171 | 0.598727 | 0.662289 | 0.508959 |
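One way to put the costs in this table in context is to compare them with two trivial policies on the validation set: flagging nobody and flagging everybody. The sketch below derives both from the validation confusion counts reported earlier (2,293 acute and 6,413 non-acute patients). It is offered as a rough reference point under the notebook's 2:1 weights, not as part of the original analysis.

```python
W_FN, W_FP = 2.0, 1.0

# Validation class totals, from the confusion counts (FN + TP and TN + FP)
positives = 2123 + 170
negatives = 6180 + 233
n = positives + negatives

# Flag nobody: every acute patient becomes a false negative
cost_flag_none = W_FN * positives / n
# Flag everybody: every non-acute patient becomes a false positive
cost_flag_all = W_FP * negatives / n

print(f"flag nobody: {cost_flag_none:.4f} | flag everybody: {cost_flag_all:.4f}")
```

The optimal costs in the table (0.5134 for NN 1 and 0.5090 for NN 2) sit only slightly below the flag-nobody figure, which fits with validation AUCs of about 0.60.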
Q2e Step 7 - Feature Selection (Iteration 3)¶
# Q2e Step 7: Model construction
# Select a small sample of data (2,500 rows, roughly 5.5% of the training set) to speed up SHAP calculation
X_sample = X_training.sample(n = 2500, random_state = 0).values.astype(np.float32)
# Explainer for TensorFlow/Keras model
explainer = shap.Explainer(nn2, X_sample)
# Compute SHAP values
shap_values = explainer(X_sample)
# Summary plot
shap.summary_plot(shap_values, features = X_sample, feature_names = X_training.columns)
# Save SHAP feature importance (mean absolute SHAP value per feature) in a DataFrame
shap_importance_df = pd.DataFrame({
    'Feature': X_training.columns,
    'Mean SHAP Value': np.abs(shap_values.values).mean(axis = 0)
}).sort_values(by = 'Mean SHAP Value', ascending = False).reset_index(drop = True)
results = []
# Test models with features from 10 to 55 in increments of 5 (e.g. 10, 15, ...)
feature_counts = list(range(10, min(56, len(shap_importance_df) + 1), 5))
for k in feature_counts:
    tf.keras.backend.clear_session()
    keras_deterministic(seed_value = 0)
    selected_features = shap_importance_df['Feature'].head(k).tolist()
    X_training_k = X_training[selected_features]
    X_val_k = X_val[selected_features]
    # Define model
    model = tf.keras.Sequential([
        layers.Input(shape = (X_training_k.shape[1],)),
        layers.Dense(64, activation = 'relu'),
        layers.Dropout(0.3),
        layers.Dense(32, activation = 'tanh'),
        layers.Dropout(0.3),
        layers.Dense(1, activation = 'sigmoid')
    ])
    model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001),
        loss = 'binary_crossentropy',
        metrics = [
            tf.keras.metrics.Precision(name = 'precision'),
            tf.keras.metrics.Recall(name = 'recall'),
            tf.keras.metrics.AUC(name = 'auc')
        ]
    )
    early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
    history = model.fit(
        X_training_k, Y_training,
        validation_data = (X_val_k, Y_val),
        epochs = 50,
        batch_size = 32,
        callbacks = [early_stopping],
        # class_weight = class_weight_dict,
        verbose = 0
    )
    # Record the best (minimum) validation loss across epochs
    val_loss = min(history.history['val_loss'])
    results.append((k, val_loss))
# Plot Validation Loss vs. Number of Features
ks, val_losses = zip(*results)
plt.figure(figsize = (8, 4))
plt.plot(ks, val_losses, marker = 'o')
plt.xlabel('Number of SHAP-Selected Features')
plt.ylabel('Validation Loss')
plt.title('Validation Loss vs. Feature Count')
plt.grid(True)
plt.tight_layout()
plt.show()
# Take the top 20 features
top_20_features = shap_importance_df['Feature'].head(20).tolist()
X_training_best = X_training[top_20_features]
X_val_best = X_val[top_20_features]
X_test_best = X_test[top_20_features]
PermutationExplainer explainer: 2501it [00:59, 39.55it/s]
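The ranking step in the cell above averages the absolute SHAP values per feature and keeps the top k. A minimal sketch of that calculation on a small synthetic array follows; the feature names and values are made up for illustration and are not taken from the Betahelf data.

```python
import numpy as np
import pandas as pd

# Hypothetical SHAP matrix: 4 samples x 3 features (illustrative values only)
toy_shap = np.array([
    [ 0.10, -0.40,  0.02],
    [-0.20,  0.30, -0.01],
    [ 0.05, -0.50,  0.03],
    [-0.15,  0.20, -0.02],
])
toy_features = ['age', 'bmi', 'region']  # hypothetical feature names

# Mean absolute SHAP value per feature; the sign is dropped so that positive
# and negative contributions do not cancel out
toy_importance = pd.DataFrame({
    'Feature': toy_features,
    'Mean SHAP Value': np.abs(toy_shap).mean(axis = 0)
}).sort_values(by = 'Mean SHAP Value', ascending = False).reset_index(drop = True)

# Keep the k most important features (k = 2 here)
top_k = toy_importance['Feature'].head(2).tolist()
print(top_k)
```

In this toy case 'bmi' ranks first even though its contributions switch sign, which is the reason for taking the absolute value before averaging.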
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Refine the neural network architecture
nn3 = keras.Sequential([
layers.Input(shape = (X_training_best.shape[1],)),
layers.Dense(64, activation = 'relu'),
layers.Dropout(0.3),
layers.Dense(32, activation = 'tanh'),
layers.Dropout(0.3),
layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn3.compile(
optimizer=keras.optimizers.Adam(learning_rate = 0.001),
loss = 'binary_crossentropy',
metrics = [
tf.keras.metrics.Precision(name = 'precision'),
tf.keras.metrics.Recall(name = 'recall'),
tf.keras.metrics.AUC(name = 'auc')
]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_3 = nn3.fit(
X_training_best, Y_training,
validation_data = (X_val_best, Y_val),
epochs = 50,
batch_size = 32,
callbacks = [early_stopping],
# class_weight = class_weight_dict,
verbose = 1
)
# Check model parameters
display(nn3.summary())
Epoch 1/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.5578 - loss: 0.5824 - precision: 0.2972 - recall: 0.0439 - val_auc: 0.6146 - val_loss: 0.5607 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - auc: 0.6007 - loss: 0.5591 - precision: 0.4118 - recall: 0.0054 - val_auc: 0.6097 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - auc: 0.6191 - loss: 0.5532 - precision: 0.3136 - recall: 0.0033 - val_auc: 0.6034 - val_loss: 0.5641 - val_precision: 0.8182 - val_recall: 0.0039
Epoch 4/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - auc: 0.6289 - loss: 0.5498 - precision: 0.5687 - recall: 0.0122 - val_auc: 0.6056 - val_loss: 0.5635 - val_precision: 0.7727 - val_recall: 0.0074
Epoch 5/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 988us/step - auc: 0.6352 - loss: 0.5473 - precision: 0.6110 - recall: 0.0200 - val_auc: 0.6034 - val_loss: 0.5652 - val_precision: 0.5588 - val_recall: 0.0166
Epoch 6/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 997us/step - auc: 0.6434 - loss: 0.5451 - precision: 0.5300 - recall: 0.0266 - val_auc: 0.6042 - val_loss: 0.5658 - val_precision: 0.4667 - val_recall: 0.0214
Model: "sequential"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| dense (Dense) | (None, 64) | 1,344 |
| dropout (Dropout) | (None, 64) | 0 |
| dense_1 (Dense) | (None, 32) | 2,080 |
| dropout_1 (Dropout) | (None, 32) | 0 |
| dense_2 (Dense) | (None, 1) | 33 |
Total params: 10,373 (40.52 KB)
Trainable params: 3,457 (13.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 6,916 (27.02 KB)
None
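The trainable parameter count in the summary above can be reproduced by hand: each Dense layer has (inputs + 1) × units parameters (weights plus one bias per unit), and Dropout layers add none. The helper below is a small illustrative function, not part of the notebook. The reported optimizer parameter count (6,916) equals 2 × 3,457 + 2, which is consistent with Adam keeping two moment estimates per trainable weight plus two scalar variables; that reading is inferred from the numbers rather than stated in the output.

```python
def dense_param_count(n_inputs, layer_units):
    """Trainable parameters of a stack of Dense layers (weights + biases)."""
    total = 0
    for units in layer_units:
        total += (n_inputs + 1) * units  # +1 accounts for the bias of each unit
        n_inputs = units                 # this layer's output feeds the next layer
    return total

# 20 SHAP-selected inputs -> Dense(64) -> Dense(32) -> Dense(1)
trainable = dense_param_count(20, [64, 32, 1])
print(trainable)  # 1,344 + 2,080 + 33
```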
# Q2e Step 7: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_3, title = 'Training vs. Validation Loss (NN 3)')
p_va = nn3.predict(X_val_best, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.2)
sum_va = confusion_summary(y_va, p_va, 0.2)
print("\nConfusion Matrix Summary (0.2 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.2 threshold) (NN 3)')
plt.show()
Confusion Matrix Summary (0.2 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.2 | 2019 | 4394 | 406 | 1887 | 0.822939 | 0.30043 |
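The summary row can be checked by hand from the four confusion matrix counts. The sketch below recomputes recall and precision, and also a weighted misclassification cost per validation patient. The form of the cost is an assumption: `cost_sensitive_error` is defined outside this section, but with a false negative weighted at 2.0 and a false positive at 1.0 this form reproduces the 0.5980 that the notebook prints for NN 3 at the 0.2 threshold.

```python
# Counts from the NN 3 validation summary at the 0.2 threshold
TN, FP, FN, TP = 2019, 4394, 406, 1887
N = TN + FP + FN + TP

recall = TP / (TP + FN)       # share of true acute cases that are flagged
precision = TP / (TP + FP)    # share of flagged patients who are truly acute

# Assumed form of the cost: weighted errors per validation patient
W_FN, W_FP = 2.0, 1.0
cost = (W_FN * FN + W_FP * FP) / N

print(f"N = {N}, recall = {recall:.6f}, precision = {precision:.6f}, cost = {cost:.6f}")
```

At this threshold the cost is dominated by the 4,394 false positives rather than the 406 false negatives, even though each false negative counts double.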
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
y_true = y_va, y_prob = p_va,
w_fn = W_FN, w_fp = W_FP,
plot = True,
return_curve = True,
plot_roc = True,
return_roc_curve = True,
model_name = 'NN 3'
)
threshold = 0.2
y_val_probs = nn3.predict(X_val_best, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.5980 | Optimal threshold: 0.3825 | Optimal Misclassification Cost: 0.5064 | Validation AUC: 0.615
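The printed line compares the cost at the fixed 0.2 threshold (0.5980) with the cost at the threshold that minimises it (0.5064 at 0.3825). The notebook's `choose_threshold` helper is defined outside this section, so the function below is only a hypothetical sketch of the general idea: sweep a grid of thresholds and keep the one with the lowest weighted cost. The name `sweep_threshold`, the grid size and the toy data are all assumptions made for illustration.

```python
import numpy as np

def sweep_threshold(y_true, y_prob, w_fn = 2.0, w_fp = 1.0, n_grid = 101):
    """Hypothetical helper: return the threshold on a fixed grid that minimises
    (w_fn * FN + w_fp * FP) / N, together with that minimum cost."""
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype = float)
    best_thr, best_cost = None, np.inf
    for thr in np.linspace(0.0, 1.0, n_grid):
        y_pred = (y_prob >= thr).astype(int)
        fn = np.sum((y_true == 1) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        cost = (w_fn * fn + w_fp * fp) / len(y_true)
        if cost < best_cost:
            best_thr, best_cost = thr, cost
    return best_thr, best_cost

# Toy data (illustrative only)
toy_y = [0, 0, 0, 1, 1, 0, 1, 0]
toy_p = [0.10, 0.20, 0.35, 0.40, 0.80, 0.15, 0.30, 0.05]
thr, cost = sweep_threshold(toy_y, toy_p)
print(f"threshold = {thr:.2f}, cost = {cost:.4f}")
```

Because the threshold is chosen on the same validation data that is used to report the cost, the optimal cost is an optimistic estimate; a held-out test set gives a fairer figure.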
# Plot decile chart
plot_decile_chart(nn3, X_val_best, Y_val, model_name = 'NN 3')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
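A decile chart groups patients by predicted risk and compares the observed acute rate in each group. `plot_decile_chart` is defined outside this section, so the following is a hypothetical sketch of the table that such a chart could be built from, using synthetic data. The function name, column names and binning choices are assumptions.

```python
import numpy as np
import pandas as pd

def decile_table(y_true, y_prob, n_bins = 10):
    """Hypothetical helper: observed positive rate per predicted-risk decile
    (decile 1 = highest predicted probabilities)."""
    df = pd.DataFrame({'y': np.asarray(y_true).astype(int),
                       'p': np.asarray(y_prob, dtype = float)})
    # Rank first so that ties in probability do not break the equal-sized bins
    ranks = df['p'].rank(method = 'first', ascending = False)
    df['decile'] = pd.qcut(ranks, n_bins, labels = False) + 1
    return df.groupby('decile').agg(
        n = ('y', 'size'),
        mean_prob = ('p', 'mean'),
        actual_rate = ('y', 'mean')
    )

# Toy data (illustrative only): the true risk rises with the predicted probability
rng = np.random.default_rng(0)
toy_p = rng.uniform(0, 1, size = 1000)
toy_y = rng.binomial(1, toy_p)
toy_deciles = decile_table(toy_y, toy_p)
print(toy_deciles)
```

For a model that ranks patients well, `actual_rate` should fall steadily from decile 1 to decile 10; a flat profile indicates little separation between high-risk and low-risk patients.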
# Create a new row with the metrics for NN 3
nn3_metrics = {
'Model Name': 'NN 3',
'Threshold': sum_va['Threshold'].iloc[0],
'Validation Loss': min(nn_hist_3.history['val_loss']),
'Misclassification Cost': cost_sens_error,
'Recall': sum_va['Recall'].iloc[0],
'Precision': sum_va['Precision'].iloc[0],
'ROC_AUC': auc,
'Optimal Threshold': thr_star,
'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn3_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
| 1 | NN 2 | 0.4 | 0.673952 | 0.624971 | 0.835587 | 0.290171 | 0.598727 | 0.662289 | 0.508959 |
| 2 | NN 3 | 0.2 | 0.560748 | 0.597978 | 0.822939 | 0.300430 | 0.614514 | 0.382513 | 0.506432 |
Savepoint
X_training_best.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_training_best.pkl')
X_val_best.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_val_best.pkl')
X_test_best.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_test_best.pkl')
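A savepoint is only useful if the frames can be restored in a later session. The sketch below shows the round trip with `pd.read_pickle` on a toy frame written to a temporary directory; in the notebook the same call would presumably be pointed at the Google Drive paths used above.

```python
import os
import tempfile
import pandas as pd

# Toy frame standing in for X_training_best (illustrative only)
toy_frame = pd.DataFrame({'feature_a': [0.1, 0.2], 'feature_b': [1, 0]})

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'X_training_best.pkl')
    toy_frame.to_pickle(path)          # save
    restored = pd.read_pickle(path)    # restore

# The restored frame has the same values, column order and dtypes
print(restored.equals(toy_frame))
```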
Q2e Step 8 - Model Architecture Tuning (Iteration 4)
# Q2e Step 8: Model construction
layer_configs = [
[128, 64], # 2 layers, wider-than-standard
[64, 32], # 2 layers, standard
[32, 16], # 2 layers, compact
[16, 8], # 2 layers, very compact
[128], # 1 layer, high capacity
[64], # 1 layer, standard
[32], # 1 layer, lightweight
[16] # 1 layer, very compact
]
dropout_rate = 0.3
threshold = 0.2
results = []
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
# Loop through architecture configurations
for layers_config in layer_configs:
    tf.keras.backend.clear_session()
    keras_deterministic(seed_value = 0)
    model = tf.keras.Sequential()
    model.add(layers.Input(shape=(X_training_best.shape[1],)))
    for i, units in enumerate(layers_config):
        activation = 'relu' if i == 0 else 'tanh'
        model.add(layers.Dense(units, activation = activation))
        model.add(layers.Dropout(dropout_rate))
    model.add(layers.Dense(1, activation = 'sigmoid'))
    model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001),
        loss = 'binary_crossentropy',
        metrics = [
            tf.keras.metrics.Precision(name = 'precision'),
            tf.keras.metrics.Recall(name = 'recall'),
            tf.keras.metrics.AUC(name = 'auc')
        ]
    )
    early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
    history = model.fit(
        X_training_best, Y_training,
        validation_data = (X_val_best, Y_val),
        epochs = 50,
        batch_size = 32,
        callbacks = [early_stopping],
        # class_weight = class_weight_dict,
        verbose = 0
    )
    # Compute metrics
    val_loss = min(history.history['val_loss'])
    y_val_probs = model.predict(X_val_best).ravel()
    auc = roc_auc_score(Y_val, y_val_probs)
    y_val_preds = (y_val_probs >= threshold).astype(int)
    cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
    precision = precision_score(Y_val, y_val_preds, zero_division = 0)
    recall = recall_score(Y_val, y_val_preds, zero_division = 0)
    results.append({
        'architecture': tuple(layers_config),
        'Val Loss': val_loss,
        'Cost-Sensitive Error': cost_sens_error,
        f'Precision@{threshold}': precision,
        f'Recall@{threshold}': recall,
        'ROC_AUC': auc
    })
    print(f"Tested {layers_config} → Val Loss: {val_loss:.4f}, Cost-Sensitive Error: {cost_sens_error:.4f}, Precision@{threshold}: {precision:.4f}, Recall@{threshold}: {recall:.4f}, ROC_AUC: {auc:.4f}")
# Display sorted results by cost-sensitive error
display(pd.DataFrame(results).sort_values('Cost-Sensitive Error'))
# Plot validation loss vs architecture
plt.figure(figsize = (10, 5))
plt.plot(['-'.join(map(str, r['architecture'])) for r in results], [r['Val Loss'] for r in results], marker = 'o')
plt.xticks(rotation = 45)
plt.xlabel("Layer Architecture (i.e. Neurons per Layer)")
plt.ylabel("Validation Loss")
plt.title("Architecture Tuning with Top 20 SHAP Features (NN 4)")
plt.grid(True)
plt.tight_layout()
plt.show()
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 454us/step
Tested [128, 64] → Val Loss: 0.5632, Cost-Sensitive Error: 0.6017, Precision@0.2: 0.2985, Recall@0.2: 0.8129, ROC_AUC: 0.6073
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 460us/step
Tested [64, 32] → Val Loss: 0.5607, Cost-Sensitive Error: 0.5980, Precision@0.2: 0.3004, Recall@0.2: 0.8229, ROC_AUC: 0.6145
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 453us/step
Tested [32, 16] → Val Loss: 0.5627, Cost-Sensitive Error: 0.5884, Precision@0.2: 0.3040, Recall@0.2: 0.8099, ROC_AUC: 0.6082
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 449us/step
Tested [16, 8] → Val Loss: 0.5610, Cost-Sensitive Error: 0.5991, Precision@0.2: 0.2996, Recall@0.2: 0.8138, ROC_AUC: 0.6137
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 424us/step
Tested [128] → Val Loss: 0.5634, Cost-Sensitive Error: 0.6099, Precision@0.2: 0.2952, Recall@0.2: 0.8147, ROC_AUC: 0.6053
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 437us/step
Tested [64] → Val Loss: 0.5624, Cost-Sensitive Error: 0.6106, Precision@0.2: 0.2954, Recall@0.2: 0.8264, ROC_AUC: 0.6079
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 436us/step
Tested [32] → Val Loss: 0.5626, Cost-Sensitive Error: 0.6147, Precision@0.2: 0.2942, Recall@0.2: 0.8382, ROC_AUC: 0.6094
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 428us/step
Tested [16] → Val Loss: 0.5585, Cost-Sensitive Error: 0.6061, Precision@0.2: 0.2979, Recall@0.2: 0.8443, ROC_AUC: 0.6214
| | architecture | Val Loss | Cost-Sensitive Error | Precision@0.2 | Recall@0.2 | ROC_AUC |
|---|---|---|---|---|---|---|
| 2 | (32, 16) | 0.562679 | 0.588445 | 0.304028 | 0.809856 | 0.608207 |
| 1 | (64, 32) | 0.560748 | 0.597978 | 0.300430 | 0.822939 | 0.614514 |
| 3 | (16, 8) | 0.560996 | 0.599127 | 0.299615 | 0.813781 | 0.613749 |
| 0 | (128, 64) | 0.563165 | 0.601654 | 0.298527 | 0.812909 | 0.607274 |
| 7 | (16,) | 0.558486 | 0.606134 | 0.297892 | 0.844309 | 0.621378 |
| 4 | (128,) | 0.563419 | 0.609924 | 0.295196 | 0.814653 | 0.605325 |
| 5 | (64,) | 0.562363 | 0.610613 | 0.295401 | 0.826428 | 0.607932 |
| 6 | (32,) | 0.562638 | 0.614748 | 0.294244 | 0.838203 | 0.609381 |
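The table is sorted by cost-sensitive error, but the ordering of architectures depends on which metric is used. The sketch below ranks the eight architectures separately on validation loss, cost-sensitive error and ROC AUC, using values copied from the table above. It is offered as one possible way to read the table, not as the selection rule the notebook applies.

```python
import pandas as pd

# Values copied from the architecture tuning table above
arch_results = pd.DataFrame({
    'architecture': ['(32, 16)', '(64, 32)', '(16, 8)', '(128, 64)',
                     '(16,)', '(128,)', '(64,)', '(32,)'],
    'Val Loss': [0.562679, 0.560748, 0.560996, 0.563165,
                 0.558486, 0.563419, 0.562363, 0.562638],
    'Cost-Sensitive Error': [0.588445, 0.597978, 0.599127, 0.601654,
                             0.606134, 0.609924, 0.610613, 0.614748],
    'ROC_AUC': [0.608207, 0.614514, 0.613749, 0.607274,
                0.621378, 0.605325, 0.607932, 0.609381]
}).set_index('architecture')

# Rank 1 = best: lower is better for loss and cost, higher is better for AUC
ranks = pd.DataFrame({
    'Val Loss rank': arch_results['Val Loss'].rank(),
    'Cost rank': arch_results['Cost-Sensitive Error'].rank(),
    'AUC rank': arch_results['ROC_AUC'].rank(ascending = False)
}).astype(int)
print(ranks)
```

No single architecture is best on all three metrics, and the differences between them are small (a few thousandths on each metric), so the choice between them rests on judgement such as a preference for fewer parameters.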
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Refine the neural network architecture
nn4 = keras.Sequential([
layers.Input(shape = (X_training_best.shape[1],)),
# Decrease neurons in each layer
layers.Dense(16, activation = 'relu'),
layers.Dropout(0.3),
layers.Dense(8, activation = 'tanh'),
layers.Dropout(0.3),
layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn4.compile(
optimizer = keras.optimizers.Adam(learning_rate = 0.001),
loss = 'binary_crossentropy',
metrics = [
tf.keras.metrics.Precision(name = 'precision'),
tf.keras.metrics.Recall(name = 'recall'),
tf.keras.metrics.AUC(name = 'auc')
]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_4 = nn4.fit(
X_training_best, Y_training,
validation_data = (X_val_best, Y_val),
epochs = 50,
batch_size = 32,
callbacks = [early_stopping],
# class_weight = class_weight_dict,
verbose = 1
)
# Check model parameters
display(nn4.summary())
Epoch 1/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.5221 - loss: 0.6302 - precision: 0.2607 - recall: 0.1392 - val_auc: 0.6141 - val_loss: 0.5617 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 970us/step - auc: 0.5874 - loss: 0.5618 - precision: 0.3400 - recall: 0.0024 - val_auc: 0.6140 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 915us/step - auc: 0.6044 - loss: 0.5575 - precision: 0.4489 - recall: 0.0024 - val_auc: 0.6131 - val_loss: 0.5611 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 902us/step - auc: 0.6122 - loss: 0.5551 - precision: 0.3961 - recall: 0.0031 - val_auc: 0.6107 - val_loss: 0.5621 - val_precision: 0.1667 - val_recall: 4.3611e-04
Epoch 5/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 936us/step - auc: 0.6176 - loss: 0.5535 - precision: 0.4565 - recall: 0.0068 - val_auc: 0.6088 - val_loss: 0.5624 - val_precision: 0.2500 - val_recall: 4.3611e-04
Epoch 6/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 910us/step - auc: 0.6206 - loss: 0.5529 - precision: 0.5214 - recall: 0.0057 - val_auc: 0.6076 - val_loss: 0.5630 - val_precision: 0.2857 - val_recall: 8.7222e-04
Epoch 7/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 904us/step - auc: 0.6244 - loss: 0.5517 - precision: 0.5178 - recall: 0.0077 - val_auc: 0.6070 - val_loss: 0.5635 - val_precision: 0.2500 - val_recall: 4.3611e-04
Model: "sequential"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| dense (Dense) | (None, 16) | 336 |
| dropout (Dropout) | (None, 16) | 0 |
| dense_1 (Dense) | (None, 8) | 136 |
| dropout_1 (Dropout) | (None, 8) | 0 |
| dense_2 (Dense) | (None, 1) | 9 |
Total params: 1,445 (5.65 KB)
Trainable params: 481 (1.88 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 964 (3.77 KB)
None
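Training stopped after 7 of the 50 permitted epochs. The sketch below replays the patience rule on the logged `val_loss` values to show why; it is a simplified stand-in for the `EarlyStopping` callback (assuming no minimum improvement margin), not the Keras implementation itself.

```python
def simulate_early_stopping(val_losses, patience = 5):
    """Return (best_epoch, stop_epoch), both 1-indexed, for a val_loss sequence.
    Training stops once `patience` epochs have passed without beating the best
    validation loss seen so far."""
    best_loss, best_epoch, wait = float('inf'), None, 0
    for epoch, loss in enumerate(val_losses, start = 1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)

# val_loss per epoch, copied from the NN 4 training log above
nn4_val_losses = [0.5617, 0.5610, 0.5611, 0.5621, 0.5624, 0.5630, 0.5635]
best_epoch, stop_epoch = simulate_early_stopping(nn4_val_losses)
print(best_epoch, stop_epoch)
```

With `restore_best_weights = True`, the weights used for the later predictions are those from the best epoch, which is why `min(nn_hist_4.history['val_loss'])` is the validation loss recorded for the model.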
# Q2e Step 8: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_4, title = 'Training vs. Validation Loss (NN 4)')
p_va = nn4.predict(X_val_best, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.2)
sum_va = confusion_summary(y_va, p_va, 0.2)
print("\nConfusion Matrix Summary (0.2 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.2 threshold) (NN 4)')
plt.show()
Confusion Matrix Summary (0.2 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.2 | 2051 | 4362 | 427 | 1866 | 0.813781 | 0.299615 |
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
y_true = y_va, y_prob = p_va,
w_fn = W_FN, w_fp = W_FP,
plot = True,
return_curve = True,
plot_roc = True,
return_roc_curve = True,
model_name = 'NN 4'
)
threshold = 0.2
y_val_probs = nn4.predict(X_val_best, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.5991 | Optimal threshold: 0.2874 | Optimal Misclassification Cost: 0.4968 | Validation AUC: 0.614
# Plot decile chart
plot_decile_chart(nn4, X_val_best, Y_val, model_name = 'NN 4')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
# Create a new row with the metrics for NN 4
nn4_metrics = {
'Model Name': 'NN 4',
'Threshold': sum_va['Threshold'].iloc[0],
'Validation Loss': min(nn_hist_4.history['val_loss']),
'Misclassification Cost': cost_sens_error,
'Recall': sum_va['Recall'].iloc[0],
'Precision': sum_va['Precision'].iloc[0],
'ROC_AUC': auc,
'Optimal Threshold': thr_star,
'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn4_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
| 1 | NN 2 | 0.4 | 0.673952 | 0.624971 | 0.835587 | 0.290171 | 0.598727 | 0.662289 | 0.508959 |
| 2 | NN 3 | 0.2 | 0.560748 | 0.597978 | 0.822939 | 0.300430 | 0.614514 | 0.382513 | 0.506432 |
| 3 | NN 4 | 0.2 | 0.560996 | 0.599127 | 0.813781 | 0.299615 | 0.613749 | 0.287431 | 0.496784 |
Q2e Step 9 - Optimiser and Learning Rate Tuning (Iteration 5)
# Q2e Step 9: Model construction
optimisers_learnrates_config = [
('Adam', 0.005),
('Adam', 0.001),
('Adam', 0.0005),
('Adam', 0.0001),
('SGD', 0.005),
('SGD', 0.001),
('SGD', 0.0005),
('SGD', 0.0001)
]
threshold = 0.2
results = []
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
# Loop through optimiser and learn rate configurations
for opt_name, lr in optimisers_learnrates_config:
    tf.keras.backend.clear_session()
    keras_deterministic(seed_value = 0)
    # Select optimiser
    if opt_name == 'Adam':
        optimiser = tf.keras.optimizers.Adam(learning_rate = lr)
    elif opt_name == 'SGD':
        optimiser = tf.keras.optimizers.SGD(learning_rate = lr, momentum = 0.9)
    model = tf.keras.Sequential([
        layers.Input(shape = (X_training_best.shape[1],)),
        layers.Dense(16, activation = 'relu'),
        layers.Dropout(0.3),
        layers.Dense(8, activation = 'tanh'),
        layers.Dropout(0.3),
        layers.Dense(1, activation = 'sigmoid')
    ])
    model.compile(
        optimizer = optimiser,
        loss = 'binary_crossentropy',
        metrics = [
            tf.keras.metrics.Precision(name = 'precision'),
            tf.keras.metrics.Recall(name = 'recall'),
            tf.keras.metrics.AUC(name = 'auc')
        ]
    )
    early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
    history = model.fit(
        X_training_best, Y_training,
        validation_data = (X_val_best, Y_val),
        epochs = 50,
        batch_size = 32,
        callbacks = [early_stopping],
        # class_weight = class_weight_dict,
        verbose = 0
    )
    # Compute metrics
    val_loss = min(history.history['val_loss'])
    y_val_probs = model.predict(X_val_best).ravel()
    auc = roc_auc_score(Y_val, y_val_probs)
    y_val_preds = (y_val_probs >= threshold).astype(int)
    cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
    precision = precision_score(Y_val, y_val_preds, zero_division = 0)
    recall = recall_score(Y_val, y_val_preds, zero_division = 0)
    results.append({
        'Optimiser': opt_name,
        'Learn Rate': lr,
        'Val Loss': val_loss,
        'Cost-Sensitive Error': cost_sens_error,
        f'Precision@{threshold}': precision,
        f'Recall@{threshold}': recall,
        'ROC_AUC': auc
    })
    print(f"Tested {opt_name} (lr = {lr}) → Val Loss: {val_loss:.4f}, Cost-Sensitive Error: {cost_sens_error:.4f}, Precision@{threshold}: {precision:.4f}, Recall@{threshold}: {recall:.4f}, ROC_AUC: {auc:.4f}")
# Display sorted results by cost-sensitive error
results_nn5 = pd.DataFrame(results)
display(results_nn5.sort_values('Cost-Sensitive Error'))
# Plotting
plt.figure(figsize = (10, 6))
# Subplot 1: Validation Loss
plt.subplot(2, 2, 1)
for opt in results_nn5['Optimiser'].unique():
    subset = results_nn5[results_nn5['Optimiser'] == opt]
    plt.plot(subset['Learn Rate'], subset['Val Loss'], marker = 'o', label = opt)
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Validation Loss')
plt.title('Validation Loss vs Learning Rate')
plt.legend()
# Subplot 2: Cost-Sensitive Error
plt.subplot(2, 2, 2)
for opt in results_nn5['Optimiser'].unique():
    subset = results_nn5[results_nn5['Optimiser'] == opt]
    plt.plot(subset['Learn Rate'], subset['Cost-Sensitive Error'], marker = 'o', label = opt)
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Cost-Sensitive Error')
plt.title('Cost-Sensitive Error vs Learning Rate')
plt.legend()
# Subplot 3: Precision and Recall
plt.subplot(2, 2, 3)
for opt in results_nn5['Optimiser'].unique():
    subset = results_nn5[results_nn5['Optimiser'] == opt]
    plt.plot(subset['Learn Rate'], subset[f'Precision@{threshold}'], marker = 'o', label = f'{opt} - Precision@{threshold}')
    plt.plot(subset['Learn Rate'], subset[f'Recall@{threshold}'], marker = 'x', label = f'{opt} - Recall@{threshold}')
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Score')
plt.title('Precision and Recall vs Learning Rate')
plt.legend()
# Subplot 4: ROC_AUC
plt.subplot(2, 2, 4)
for opt in results_nn5['Optimiser'].unique():
    subset = results_nn5[results_nn5['Optimiser'] == opt]
    plt.plot(subset['Learn Rate'], subset['ROC_AUC'], marker = 'o', label = opt)
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Score')
plt.title('ROC_AUC vs Learning Rate')
plt.legend()
plt.tight_layout()
plt.show()
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 452us/step
Tested Adam (lr = 0.005) → Val Loss: 0.5621, Cost-Sensitive Error: 0.5918, Precision@0.2: 0.3020, Recall@0.2: 0.7924, ROC_AUC: 0.6100
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 454us/step
Tested Adam (lr = 0.001) → Val Loss: 0.5610, Cost-Sensitive Error: 0.5991, Precision@0.2: 0.2996, Recall@0.2: 0.8138, ROC_AUC: 0.6137
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 467us/step
Tested Adam (lr = 0.0005) → Val Loss: 0.5606, Cost-Sensitive Error: 0.5898, Precision@0.2: 0.3035, Recall@0.2: 0.8112, ROC_AUC: 0.6173
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 473us/step
Tested Adam (lr = 0.0001) → Val Loss: 0.5607, Cost-Sensitive Error: 0.5887, Precision@0.2: 0.3039, Recall@0.2: 0.8103, ROC_AUC: 0.6187
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 461us/step
Tested SGD (lr = 0.005) → Val Loss: 0.5608, Cost-Sensitive Error: 0.5802, Precision@0.2: 0.3075, Recall@0.2: 0.8051, ROC_AUC: 0.6165
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 466us/step
Tested SGD (lr = 0.001) → Val Loss: 0.5605, Cost-Sensitive Error: 0.5820, Precision@0.2: 0.3066, Recall@0.2: 0.8007, ROC_AUC: 0.6183
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 460us/step
Tested SGD (lr = 0.0005) → Val Loss: 0.5612, Cost-Sensitive Error: 0.5899, Precision@0.2: 0.3032, Recall@0.2: 0.8051, ROC_AUC: 0.6164
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 463us/step
Tested SGD (lr = 0.0001) → Val Loss: 0.5663, Cost-Sensitive Error: 0.6545, Precision@0.2: 0.2824, Recall@0.2: 0.8962, ROC_AUC: 0.6061
| | Optimiser | Learn Rate | Val Loss | Cost-Sensitive Error | Precision@0.2 | Recall@0.2 | ROC_AUC |
|---|---|---|---|---|---|---|---|
| 4 | SGD | 0.0050 | 0.560756 | 0.580175 | 0.307513 | 0.805059 | 0.616471 |
| 5 | SGD | 0.0010 | 0.560492 | 0.582012 | 0.306562 | 0.800698 | 0.618289 |
| 3 | Adam | 0.0001 | 0.560706 | 0.588674 | 0.303942 | 0.810292 | 0.618668 |
| 2 | Adam | 0.0005 | 0.560640 | 0.589823 | 0.303475 | 0.811164 | 0.617317 |
| 6 | SGD | 0.0005 | 0.561215 | 0.589938 | 0.303219 | 0.805059 | 0.616390 |
| 0 | Adam | 0.0050 | 0.562116 | 0.591776 | 0.301978 | 0.792412 | 0.610009 |
| 1 | Adam | 0.0010 | 0.560996 | 0.599127 | 0.299615 | 0.813781 | 0.613749 |
| 7 | SGD | 0.0001 | 0.566298 | 0.654491 | 0.282397 | 0.896206 | 0.606052 |
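The SGD runs above use `momentum = 0.9`, so each weight update blends the previous update with the new gradient step. The sketch below shows the standard (non-Nesterov) momentum rule on a one-parameter toy problem; the function name and the toy objective are assumptions made for illustration.

```python
import numpy as np

def sgd_momentum_step(w, velocity, grad, lr = 0.001, momentum = 0.9):
    """One (non-Nesterov) momentum update:
    velocity <- momentum * velocity - lr * grad;  w <- w + velocity"""
    velocity = momentum * velocity - lr * grad
    return w + velocity, velocity

# Toy problem (illustrative only): minimise f(w) = 0.5 * w**2, so grad = w
w, v = np.array([1.0]), np.array([0.0])
for _ in range(3):
    w, v = sgd_momentum_step(w, v, grad = w)
print(w, v)
```

The size of each step scales with the learning rate, so a very small rate covers little ground within the 50-epoch budget. That is one plausible explanation, not a tested one, for why SGD at 0.0001 has the highest validation loss and cost-sensitive error in the table.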
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Refine the neural network architecture
nn5 = keras.Sequential([
layers.Input(shape = (X_training_best.shape[1],)),
layers.Dense(16, activation = 'relu'),
layers.Dropout(0.3),
layers.Dense(8, activation = 'tanh'),
layers.Dropout(0.3),
layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn5.compile(
optimizer = keras.optimizers.SGD(learning_rate = 0.001, momentum = 0.9),
loss = 'binary_crossentropy',
metrics = [
tf.keras.metrics.Precision(name = 'precision'),
tf.keras.metrics.Recall(name = 'recall'),
tf.keras.metrics.AUC(name = 'auc')
]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_5 = nn5.fit(
X_training_best, Y_training,
validation_data = (X_val_best, Y_val),
epochs = 50,
batch_size = 32,
callbacks = [early_stopping],
# class_weight = class_weight_dict,
verbose = 1
)
# Check model parameters
display(nn5.summary())
Epoch 1/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.4991 - loss: 0.6511 - precision: 0.2511 - recall: 0.1766 - val_auc: 0.5571 - val_loss: 0.5730 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 904us/step - auc: 0.5325 - loss: 0.5706 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_auc: 0.5883 - val_loss: 0.5696 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 852us/step - auc: 0.5511 - loss: 0.5667 - precision: 0.6399 - recall: 4.5543e-04 - val_auc: 0.5981 - val_loss: 0.5683 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 876us/step - auc: 0.5646 - loss: 0.5641 - precision: 0.1334 - recall: 4.4333e-05 - val_auc: 0.6037 - val_loss: 0.5671 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 5/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.5687 - loss: 0.5637 - precision: 0.3660 - recall: 4.8436e-04 - val_auc: 0.6063 - val_loss: 0.5662 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 6/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 867us/step - auc: 0.5762 - loss: 0.5623 - precision: 0.6731 - recall: 5.4479e-04 - val_auc: 0.6076 - val_loss: 0.5655 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 7/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 889us/step - auc: 0.5780 - loss: 0.5619 - precision: 0.6248 - recall: 0.0010 - val_auc: 0.6094 - val_loss: 0.5648 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 8/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 880us/step - auc: 0.5822 - loss: 0.5609 - precision: 0.4420 - recall: 7.3995e-04 - val_auc: 0.6099 - val_loss: 0.5644 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 9/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 876us/step - auc: 0.5898 - loss: 0.5593 - precision: 0.4178 - recall: 0.0013 - val_auc: 0.6109 - val_loss: 0.5640 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 10/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 877us/step - auc: 0.5860 - loss: 0.5606 - precision: 0.6126 - recall: 0.0013 - val_auc: 0.6112 - val_loss: 0.5638 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 11/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 893us/step - auc: 0.5900 - loss: 0.5596 - precision: 0.6621 - recall: 0.0011 - val_auc: 0.6114 - val_loss: 0.5635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 12/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 883us/step - auc: 0.5923 - loss: 0.5591 - precision: 0.7082 - recall: 0.0016 - val_auc: 0.6117 - val_loss: 0.5633 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 13/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 844us/step - auc: 0.5921 - loss: 0.5591 - precision: 0.5131 - recall: 0.0012 - val_auc: 0.6123 - val_loss: 0.5630 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 14/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 875us/step - auc: 0.5968 - loss: 0.5587 - precision: 0.3049 - recall: 8.6076e-04 - val_auc: 0.6130 - val_loss: 0.5628 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 15/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 874us/step - auc: 0.5980 - loss: 0.5583 - precision: 0.2970 - recall: 7.3695e-04 - val_auc: 0.6128 - val_loss: 0.5627 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 16/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 867us/step - auc: 0.5958 - loss: 0.5583 - precision: 0.4246 - recall: 8.2051e-04 - val_auc: 0.6127 - val_loss: 0.5627 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 17/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 852us/step - auc: 0.5974 - loss: 0.5581 - precision: 0.4016 - recall: 0.0015 - val_auc: 0.6132 - val_loss: 0.5625 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 18/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 859us/step - auc: 0.6010 - loss: 0.5575 - precision: 0.6782 - recall: 0.0016 - val_auc: 0.6134 - val_loss: 0.5623 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 19/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 884us/step - auc: 0.6014 - loss: 0.5572 - precision: 0.5393 - recall: 0.0022 - val_auc: 0.6145 - val_loss: 0.5622 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 20/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 868us/step - auc: 0.6031 - loss: 0.5571 - precision: 0.5741 - recall: 0.0017 - val_auc: 0.6152 - val_loss: 0.5618 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 21/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.6042 - loss: 0.5569 - precision: 0.5381 - recall: 0.0012 - val_auc: 0.6154 - val_loss: 0.5618 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 22/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 866us/step - auc: 0.6062 - loss: 0.5562 - precision: 0.5181 - recall: 0.0020 - val_auc: 0.6156 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 23/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.6050 - loss: 0.5568 - precision: 0.4053 - recall: 0.0013 - val_auc: 0.6153 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 24/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 874us/step - auc: 0.6064 - loss: 0.5563 - precision: 0.4677 - recall: 0.0019 - val_auc: 0.6153 - val_loss: 0.5615 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 25/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 871us/step - auc: 0.6103 - loss: 0.5555 - precision: 0.3198 - recall: 0.0011 - val_auc: 0.6156 - val_loss: 0.5614 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 26/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 859us/step - auc: 0.6104 - loss: 0.5553 - precision: 0.4232 - recall: 0.0016 - val_auc: 0.6162 - val_loss: 0.5613 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 27/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 867us/step - auc: 0.6091 - loss: 0.5555 - precision: 0.4537 - recall: 0.0013 - val_auc: 0.6163 - val_loss: 0.5612 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 28/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 877us/step - auc: 0.6086 - loss: 0.5557 - precision: 0.4519 - recall: 0.0016 - val_auc: 0.6167 - val_loss: 0.5612 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 29/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.6104 - loss: 0.5554 - precision: 0.3590 - recall: 0.0018 - val_auc: 0.6168 - val_loss: 0.5611 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 30/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 875us/step - auc: 0.6119 - loss: 0.5552 - precision: 0.4485 - recall: 0.0013 - val_auc: 0.6175 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 31/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 871us/step - auc: 0.6140 - loss: 0.5547 - precision: 0.6222 - recall: 0.0019 - val_auc: 0.6174 - val_loss: 0.5609 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 32/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 862us/step - auc: 0.6108 - loss: 0.5552 - precision: 0.6002 - recall: 0.0015 - val_auc: 0.6178 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 33/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 877us/step - auc: 0.6135 - loss: 0.5546 - precision: 0.3840 - recall: 0.0014 - val_auc: 0.6174 - val_loss: 0.5609 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 34/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 889us/step - auc: 0.6137 - loss: 0.5549 - precision: 0.4721 - recall: 0.0013 - val_auc: 0.6179 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 35/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 863us/step - auc: 0.6133 - loss: 0.5546 - precision: 0.5580 - recall: 0.0022 - val_auc: 0.6177 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 36/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 887us/step - auc: 0.6135 - loss: 0.5544 - precision: 0.3738 - recall: 0.0013 - val_auc: 0.6178 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 37/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 910us/step - auc: 0.6144 - loss: 0.5542 - precision: 0.5167 - recall: 0.0016 - val_auc: 0.6183 - val_loss: 0.5607 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 38/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 877us/step - auc: 0.6134 - loss: 0.5547 - precision: 0.4703 - recall: 0.0021 - val_auc: 0.6180 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 39/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 884us/step - auc: 0.6177 - loss: 0.5533 - precision: 0.5408 - recall: 0.0020 - val_auc: 0.6181 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 40/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 863us/step - auc: 0.6208 - loss: 0.5526 - precision: 0.4139 - recall: 0.0015 - val_auc: 0.6185 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 41/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 882us/step - auc: 0.6193 - loss: 0.5528 - precision: 0.5167 - recall: 0.0020 - val_auc: 0.6187 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 42/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 866us/step - auc: 0.6153 - loss: 0.5541 - precision: 0.3757 - recall: 0.0019 - val_auc: 0.6185 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 43/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 885us/step - auc: 0.6152 - loss: 0.5543 - precision: 0.3453 - recall: 0.0011 - val_auc: 0.6180 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 44/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 887us/step - auc: 0.6161 - loss: 0.5539 - precision: 0.5181 - recall: 0.0018 - val_auc: 0.6187 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 45/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 861us/step - auc: 0.6152 - loss: 0.5539 - precision: 0.5382 - recall: 0.0021 - val_auc: 0.6180 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 46/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 868us/step - auc: 0.6190 - loss: 0.5529 - precision: 0.4149 - recall: 0.0022 - val_auc: 0.6177 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 47/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 858us/step - auc: 0.6171 - loss: 0.5536 - precision: 0.3479 - recall: 9.3537e-04 - val_auc: 0.6180 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Model: "sequential"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| dense (Dense) | (None, 16) | 336 |
| dropout (Dropout) | (None, 16) | 0 |
| dense_1 (Dense) | (None, 8) | 136 |
| dropout_1 (Dropout) | (None, 8) | 0 |
| dense_2 (Dense) | (None, 1) | 9 |
Total params: 964 (3.77 KB)
Trainable params: 481 (1.88 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 483 (1.89 KB)
None
# Q2e Step 9: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_5, title = 'Training vs. Validation Loss (NN 5)')
p_va = nn5.predict(X_val_best, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.2)
sum_va = confusion_summary(y_va, p_va, 0.2)
print("\nConfusion Matrix Summary (0.2 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.2 threshold) (NN 5)')
plt.show()
Confusion Matrix Summary (0.2 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.2 | 2260 | 4153 | 457 | 1836 | 0.800698 | 0.306562 |
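The recall and precision in the summary row follow directly from the four confusion counts. A minimal sketch that recomputes them by hand is given here; the variable names are illustrative and are not the notebook's `confusion_summary` helper, which is defined elsewhere.

```python
# Confusion counts copied from the validation summary at the 0.2 threshold
tn, fp, fn, tp = 2260, 4153, 457, 1836

n = tn + fp + fn + tp
recall = tp / (tp + fn)      # share of true acute cases that were flagged
precision = tp / (tp + fp)   # share of flagged patients who were truly acute
flag_rate = (tp + fp) / n    # share of all validation patients that were flagged

print(f"N = {n}, recall = {recall:.6f}, precision = {precision:.6f}, flagged = {flag_rate:.3f}")
```

The flagged share is worth reading alongside the recall: at this threshold most of the validation set is flagged, which is why precision stays close to the base rate of acute cases.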
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
y_true = y_va, y_prob = p_va,
w_fn = W_FN, w_fp = W_FP,
plot = True,
return_curve = True,
plot_roc = True,
return_roc_curve = True,
model_name = 'NN 5'
)
threshold = 0.2
y_val_probs = nn5.predict(X_val_best, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.5820 | Optimal threshold: 0.3263 | Optimal Misclassification Cost: 0.5054 | Validation AUC: 0.618
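The `cost_sensitive_error` helper is defined earlier in the notebook and is not shown in this section. The sketch below assumes it is the weighted error count divided by the number of validation rows; that assumed form reproduces the 0.5820 printed above from the confusion counts, but the real helper may differ in detail. `weighted_cost` is a hypothetical name.

```python
def weighted_cost(fn, fp, n, w_fn=2.0, w_fp=1.0):
    """Hypothetical stand-in for cost_sensitive_error: weighted errors per validation row."""
    return (w_fn * fn + w_fp * fp) / n

# Counts from the validation confusion matrix at the 0.2 threshold
cost_at_02 = weighted_cost(fn=457, fp=4153, n=8706)

# Two trivial baselines on the same validation set (2293 acute, 6413 non-acute)
cost_flag_none = weighted_cost(fn=2293, fp=0, n=8706)
cost_flag_all = weighted_cost(fn=0, fp=6413, n=8706)

print(f"threshold 0.2: {cost_at_02:.4f}")
print(f"flag nobody:   {cost_flag_none:.4f}")
print(f"flag everyone: {cost_flag_all:.4f}")
```

Under this assumed definition, the baselines give a yardstick for the tuned thresholds: any misclassification cost reported for a model can be compared against the cost of flagging nobody or everybody.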
# Plot decile chart
plot_decile_chart(nn5, X_val_best, Y_val, model_name = 'NN 5')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
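`plot_decile_chart` is also defined outside this section. As a rough illustration of what a decile chart usually summarises, the sketch below ranks patients by predicted probability, splits them into ten equal groups and reports the observed acute rate and lift per group. The function name `decile_table` and the synthetic data are assumptions made for the example, not the notebook's implementation.

```python
import numpy as np
import pandas as pd

def decile_table(y_true, y_prob, n_bins=10):
    """Observed event rate and lift per score decile (decile 1 = highest scores)."""
    df = pd.DataFrame({"y": np.asarray(y_true).ravel(),
                       "p": np.asarray(y_prob).ravel()})
    # Rank first so that tied probabilities cannot collapse the bin edges
    ranks = df["p"].rank(method="first", ascending=False)
    df["decile"] = pd.qcut(ranks, n_bins, labels=False) + 1
    out = df.groupby("decile").agg(patients=("y", "size"),
                                   acute=("y", "sum"),
                                   mean_prob=("p", "mean"))
    out["acute_rate"] = out["acute"] / out["patients"]
    out["lift"] = out["acute_rate"] / df["y"].mean()
    return out

# Synthetic scores and outcomes, for illustration only
rng = np.random.default_rng(0)
p_demo = rng.uniform(size=1000)
y_demo = rng.binomial(1, p_demo)
table = decile_table(y_demo, p_demo)
print(table)
```

A model with useful ranking power shows a lift well above 1 in the first decile and well below 1 in the last.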
# Create a new row with the metrics for NN 5
nn5_metrics = {
'Model Name': 'NN 5',
'Threshold': sum_va['Threshold'].iloc[0],
'Validation Loss': min(nn_hist_5.history['val_loss']),
'Misclassification Cost': cost_sens_error,
'Recall': sum_va['Recall'].iloc[0],
'Precision': sum_va['Precision'].iloc[0],
'ROC_AUC': auc,
'Optimal Threshold': thr_star,
'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn5_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
| 1 | NN 2 | 0.4 | 0.673952 | 0.624971 | 0.835587 | 0.290171 | 0.598727 | 0.662289 | 0.508959 |
| 2 | NN 3 | 0.2 | 0.560748 | 0.597978 | 0.822939 | 0.300430 | 0.614514 | 0.382513 | 0.506432 |
| 3 | NN 4 | 0.2 | 0.560996 | 0.599127 | 0.813781 | 0.299615 | 0.613749 | 0.287431 | 0.496784 |
| 4 | NN 5 | 0.2 | 0.560492 | 0.582012 | 0.800698 | 0.306562 | 0.618289 | 0.326269 | 0.505399 |
Q2e Step 10 - Batch Size and Epochs Tuning (Iteration 6)
# Q2e Step 10: Model construction
batch_sizes = [16, 32, 64]
epochs_options = [30, 50, 70]
dropout_rate = 0.3
threshold = 0.2
results = []
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
# Loop through batch sizes and epoch configurations
for batch_size in batch_sizes:
    for epochs in epochs_options:
        tf.keras.backend.clear_session()
        keras_deterministic(seed_value = 0)
        model = tf.keras.Sequential([
            layers.Input(shape = (X_training_best.shape[1],)),
            layers.Dense(16, activation = 'relu'),
            layers.Dropout(dropout_rate),
            layers.Dense(8, activation = 'tanh'),
            layers.Dropout(dropout_rate),
            layers.Dense(1, activation = 'sigmoid')
        ])
        model.compile(
            optimizer = keras.optimizers.SGD(learning_rate = 0.001, momentum = 0.9),
            loss = 'binary_crossentropy',
            metrics = [
                tf.keras.metrics.Precision(name = 'precision'),
                tf.keras.metrics.Recall(name = 'recall'),
                tf.keras.metrics.AUC(name = 'auc')
            ])
        early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
        history = model.fit(
            X_training_best, Y_training,
            validation_data = (X_val_best, Y_val),
            epochs = epochs,
            batch_size = batch_size,
            callbacks = [early_stopping],
            # class_weight = class_weight_dict,
            verbose = 0
        )
        # Compute metrics
        # Best validation loss of this run. It must be bound to val_loss, the name used below;
        # otherwise a stale val_loss from an earlier cell is reported for every configuration
        val_loss = min(history.history['val_loss'])
        y_val_probs = model.predict(X_val_best).ravel()
        auc = roc_auc_score(Y_val, y_val_probs)
        y_val_preds = (y_val_probs >= threshold).astype(int)
        cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
        precision = precision_score(Y_val, y_val_preds, zero_division = 0)
        recall = recall_score(Y_val, y_val_preds, zero_division = 0)
        results.append({
            'Batch Size': batch_size,
            'Epochs': epochs,
            'Val Loss': val_loss,
            'Cost-Sensitive Error': cost_sens_error,
            f'Precision@{threshold}': precision,
            f'Recall@{threshold}': recall,
            'ROC_AUC': auc
        })
        print(f"Batch Size {batch_size}, Epochs {epochs} → Val Loss: {val_loss:.4f}, Cost-Sensitive Error: {cost_sens_error:.4f}, Precision@{threshold}: {precision:.4f}, Recall@{threshold}: {recall:.4f}, ROC_AUC: {auc:.4f}")
# Display sorted results by cost-sensitive error
results_nn6 = pd.DataFrame(results)
display(results_nn6.sort_values('Cost-Sensitive Error'))
# Plot metrics
metrics = ['Val Loss', 'Cost-Sensitive Error', f"Precision@{threshold}", f"Recall@{threshold}", 'ROC_AUC']
fig, axes = plt.subplots(3, 2, figsize = (12, 10))
axes = axes.flatten()
for i, metric in enumerate(metrics):
    for batch_size in results_nn6['Batch Size'].unique():
        subset = results_nn6[results_nn6['Batch Size'] == batch_size]
        axes[i].plot(subset['Epochs'], subset[metric], marker = 'o', label = f'Batch {batch_size}')
    axes[i].set_title(metric)
    axes[i].set_xlabel('Epochs')
    axes[i].set_ylabel(metric)
    axes[i].legend()
    axes[i].grid(True)
# Turn off the axis for the last subplot
axes[-1].axis('off')
plt.tight_layout()
plt.show()
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 445us/step
Batch Size 16, Epochs 30 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5787, Precision@0.2: 0.3080, Recall@0.2: 0.7976, ROC_AUC: 0.6180
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 465us/step
Batch Size 16, Epochs 50 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5787, Precision@0.2: 0.3080, Recall@0.2: 0.7976, ROC_AUC: 0.6180
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 458us/step
Batch Size 16, Epochs 70 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5787, Precision@0.2: 0.3080, Recall@0.2: 0.7976, ROC_AUC: 0.6180
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 444us/step
Batch Size 32, Epochs 30 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5836, Precision@0.2: 0.3058, Recall@0.2: 0.8007, ROC_AUC: 0.6174
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 478us/step
Batch Size 32, Epochs 50 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5820, Precision@0.2: 0.3066, Recall@0.2: 0.8007, ROC_AUC: 0.6183
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 443us/step
Batch Size 32, Epochs 70 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5820, Precision@0.2: 0.3066, Recall@0.2: 0.8007, ROC_AUC: 0.6183
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 447us/step
Batch Size 64, Epochs 30 → Val Loss: 0.5663, Cost-Sensitive Error: 0.6030, Precision@0.2: 0.2982, Recall@0.2: 0.8203, ROC_AUC: 0.6138
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 459us/step
Batch Size 64, Epochs 50 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5903, Precision@0.2: 0.3033, Recall@0.2: 0.8107, ROC_AUC: 0.6159
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 529us/step
Batch Size 64, Epochs 70 → Val Loss: 0.5663, Cost-Sensitive Error: 0.5818, Precision@0.2: 0.3066, Recall@0.2: 0.7981, ROC_AUC: 0.6182
| | Batch Size | Epochs | Val Loss | Cost-Sensitive Error | Precision@0.2 | Recall@0.2 | ROC_AUC |
|---|---|---|---|---|---|---|---|
| 0 | 16 | 30 | 0.566298 | 0.578681 | 0.307964 | 0.797645 | 0.618037 |
| 1 | 16 | 50 | 0.566298 | 0.578681 | 0.307964 | 0.797645 | 0.618037 |
| 2 | 16 | 70 | 0.566298 | 0.578681 | 0.307964 | 0.797645 | 0.618037 |
| 8 | 64 | 70 | 0.566298 | 0.581783 | 0.306584 | 0.798081 | 0.618237 |
| 5 | 32 | 70 | 0.566298 | 0.582012 | 0.306562 | 0.800698 | 0.618289 |
| 4 | 32 | 50 | 0.566298 | 0.582012 | 0.306562 | 0.800698 | 0.618289 |
| 3 | 32 | 30 | 0.566298 | 0.583620 | 0.305847 | 0.800698 | 0.617392 |
| 7 | 64 | 50 | 0.566298 | 0.590283 | 0.303263 | 0.810728 | 0.615881 |
| 6 | 64 | 30 | 0.566298 | 0.603032 | 0.298240 | 0.820323 | 0.613757 |
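Several rows in this grid are identical across epoch budgets (for example batch size 32 with 50 and 70 epochs). One plausible explanation is early stopping: with `patience = 5` and `restore_best_weights = True`, a run that stops before the smaller budget is exhausted is unaffected by a larger budget. The sketch below illustrates the patience rule in plain Python. It is a simplified stand-in for the Keras callback, `stop_epoch` is a hypothetical name, and the loss values are made up.

```python
def stop_epoch(val_losses, patience=5):
    """Return (last epoch run, best epoch) under a simple patience rule."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # Improvement: remember it and reset the patience counter
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Budget exhausted before patience ran out
    return len(val_losses), best_epoch

# Made-up validation losses: improvement stalls after epoch 3
curve = [0.60, 0.58, 0.57] + [0.575] * 20

run_10 = stop_epoch(curve[:10])   # budget of 10 epochs
run_20 = stop_epoch(curve[:20])   # budget of 20 epochs
run_2 = stop_epoch(curve[:2])     # budget too small to reach the plateau
print(run_10, run_20, run_2)
```

In this toy case the two larger budgets stop at the same epoch and restore the same best epoch, while the smallest budget ends before the plateau and therefore gives a different model.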
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Refine the neural network architecture
nn6 = keras.Sequential([
layers.Input(shape = (X_training_best.shape[1],)),
layers.Dense(16, activation = 'relu'),
layers.Dropout(0.3),
layers.Dense(8, activation = 'tanh'),
layers.Dropout(0.3),
layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn6.compile(
optimizer = keras.optimizers.SGD(learning_rate = 0.001, momentum = 0.9),
loss = 'binary_crossentropy',
metrics = [
tf.keras.metrics.Precision(name = 'precision'),
tf.keras.metrics.Recall(name = 'recall'),
tf.keras.metrics.AUC(name = 'auc')
]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_6 = nn6.fit(
X_training_best, Y_training,
validation_data = (X_val_best, Y_val),
epochs = 50,
batch_size = 32,
callbacks = [early_stopping],
# class_weight = class_weight_dict,
verbose = 1
)
# Check model parameters
display(nn6.summary())
Epoch 1/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.4991 - loss: 0.6511 - precision: 0.2511 - recall: 0.1766 - val_auc: 0.5571 - val_loss: 0.5730 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 2/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.5325 - loss: 0.5706 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_auc: 0.5883 - val_loss: 0.5696 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 3/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 898us/step - auc: 0.5511 - loss: 0.5667 - precision: 0.6399 - recall: 4.5543e-04 - val_auc: 0.5981 - val_loss: 0.5683 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 4/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 897us/step - auc: 0.5646 - loss: 0.5641 - precision: 0.1334 - recall: 4.4333e-05 - val_auc: 0.6037 - val_loss: 0.5671 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 5/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 881us/step - auc: 0.5687 - loss: 0.5637 - precision: 0.3660 - recall: 4.8436e-04 - val_auc: 0.6063 - val_loss: 0.5662 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 6/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.5762 - loss: 0.5623 - precision: 0.6731 - recall: 5.4479e-04 - val_auc: 0.6076 - val_loss: 0.5655 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 7/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 881us/step - auc: 0.5780 - loss: 0.5619 - precision: 0.6248 - recall: 0.0010 - val_auc: 0.6094 - val_loss: 0.5648 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 8/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.5822 - loss: 0.5609 - precision: 0.4420 - recall: 7.3995e-04 - val_auc: 0.6099 - val_loss: 0.5644 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 9/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 884us/step - auc: 0.5898 - loss: 0.5593 - precision: 0.4178 - recall: 0.0013 - val_auc: 0.6109 - val_loss: 0.5640 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 10/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 857us/step - auc: 
0.5860 - loss: 0.5606 - precision: 0.6126 - recall: 0.0013 - val_auc: 0.6112 - val_loss: 0.5638 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 11/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 841us/step - auc: 0.5900 - loss: 0.5596 - precision: 0.6621 - recall: 0.0011 - val_auc: 0.6114 - val_loss: 0.5635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 12/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 859us/step - auc: 0.5923 - loss: 0.5591 - precision: 0.7082 - recall: 0.0016 - val_auc: 0.6117 - val_loss: 0.5633 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 13/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 893us/step - auc: 0.5921 - loss: 0.5591 - precision: 0.5131 - recall: 0.0012 - val_auc: 0.6123 - val_loss: 0.5630 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 14/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 902us/step - auc: 0.5968 - loss: 0.5587 - precision: 0.3049 - recall: 8.6076e-04 - val_auc: 0.6130 - val_loss: 0.5628 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 15/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 859us/step - auc: 0.5980 - loss: 0.5583 - precision: 0.2970 - recall: 7.3695e-04 - val_auc: 0.6128 - val_loss: 0.5627 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 16/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 892us/step - auc: 0.5958 - loss: 0.5583 - precision: 0.4246 - recall: 8.2051e-04 - val_auc: 0.6127 - val_loss: 0.5627 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 17/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 885us/step - auc: 0.5974 - loss: 0.5581 - precision: 0.4016 - recall: 0.0015 - val_auc: 0.6132 - val_loss: 0.5625 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 18/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 887us/step - auc: 0.6010 - loss: 0.5575 - precision: 0.6782 - recall: 0.0016 - val_auc: 0.6134 - val_loss: 0.5623 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 19/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 896us/step - auc: 0.6014 - loss: 0.5572 - precision: 0.5393 - recall: 0.0022 - val_auc: 
0.6145 - val_loss: 0.5622 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 20/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 855us/step - auc: 0.6031 - loss: 0.5571 - precision: 0.5741 - recall: 0.0017 - val_auc: 0.6152 - val_loss: 0.5618 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 21/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 902us/step - auc: 0.6042 - loss: 0.5569 - precision: 0.5381 - recall: 0.0012 - val_auc: 0.6154 - val_loss: 0.5618 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 22/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 885us/step - auc: 0.6062 - loss: 0.5562 - precision: 0.5181 - recall: 0.0020 - val_auc: 0.6156 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 23/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 891us/step - auc: 0.6050 - loss: 0.5568 - precision: 0.4053 - recall: 0.0013 - val_auc: 0.6153 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 24/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6064 - loss: 0.5563 - precision: 0.4677 - recall: 0.0019 - val_auc: 0.6153 - val_loss: 0.5615 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 25/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6103 - loss: 0.5555 - precision: 0.3198 - recall: 0.0011 - val_auc: 0.6156 - val_loss: 0.5614 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 26/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 858us/step - auc: 0.6104 - loss: 0.5553 - precision: 0.4232 - recall: 0.0016 - val_auc: 0.6162 - val_loss: 0.5613 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 27/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.6091 - loss: 0.5555 - precision: 0.4537 - recall: 0.0013 - val_auc: 0.6163 - val_loss: 0.5612 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 28/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 875us/step - auc: 0.6086 - loss: 0.5557 - precision: 0.4519 - recall: 0.0016 - val_auc: 0.6167 - val_loss: 0.5612 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 
Epoch 29/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 884us/step - auc: 0.6104 - loss: 0.5554 - precision: 0.3590 - recall: 0.0018 - val_auc: 0.6168 - val_loss: 0.5611 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 30/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 889us/step - auc: 0.6119 - loss: 0.5552 - precision: 0.4485 - recall: 0.0013 - val_auc: 0.6175 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 31/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 875us/step - auc: 0.6140 - loss: 0.5547 - precision: 0.6222 - recall: 0.0019 - val_auc: 0.6174 - val_loss: 0.5609 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 32/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.6108 - loss: 0.5552 - precision: 0.6002 - recall: 0.0015 - val_auc: 0.6178 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 33/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 864us/step - auc: 0.6135 - loss: 0.5546 - precision: 0.3840 - recall: 0.0014 - val_auc: 0.6174 - val_loss: 0.5609 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 34/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 860us/step - auc: 0.6137 - loss: 0.5549 - precision: 0.4721 - recall: 0.0013 - val_auc: 0.6179 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 35/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 869us/step - auc: 0.6133 - loss: 0.5546 - precision: 0.5580 - recall: 0.0022 - val_auc: 0.6177 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 36/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 866us/step - auc: 0.6135 - loss: 0.5544 - precision: 0.3738 - recall: 0.0013 - val_auc: 0.6178 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 37/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 868us/step - auc: 0.6144 - loss: 0.5542 - precision: 0.5167 - recall: 0.0016 - val_auc: 0.6183 - val_loss: 0.5607 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 38/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 877us/step - auc: 0.6134 - loss: 
0.5547 - precision: 0.4703 - recall: 0.0021 - val_auc: 0.6180 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 39/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 897us/step - auc: 0.6177 - loss: 0.5533 - precision: 0.5408 - recall: 0.0020 - val_auc: 0.6181 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 40/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6208 - loss: 0.5526 - precision: 0.4139 - recall: 0.0015 - val_auc: 0.6185 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 41/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 871us/step - auc: 0.6193 - loss: 0.5528 - precision: 0.5167 - recall: 0.0020 - val_auc: 0.6187 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 42/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6153 - loss: 0.5541 - precision: 0.3757 - recall: 0.0019 - val_auc: 0.6185 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 43/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.6152 - loss: 0.5543 - precision: 0.3453 - recall: 0.0011 - val_auc: 0.6180 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 44/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 881us/step - auc: 0.6161 - loss: 0.5539 - precision: 0.5181 - recall: 0.0018 - val_auc: 0.6187 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 45/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 887us/step - auc: 0.6152 - loss: 0.5539 - precision: 0.5382 - recall: 0.0021 - val_auc: 0.6180 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 46/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 875us/step - auc: 0.6190 - loss: 0.5529 - precision: 0.4149 - recall: 0.0022 - val_auc: 0.6177 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 Epoch 47/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 891us/step - auc: 0.6171 - loss: 0.5536 - precision: 0.3479 - recall: 9.3537e-04 - val_auc: 0.6180 - val_loss: 
0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Model: "sequential"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| dense (Dense) | (None, 16) | 336 |
| dropout (Dropout) | (None, 16) | 0 |
| dense_1 (Dense) | (None, 8) | 136 |
| dropout_1 (Dropout) | (None, 8) | 0 |
| dense_2 (Dense) | (None, 1) | 9 |
Total params: 964 (3.77 KB)
Trainable params: 481 (1.88 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 483 (1.89 KB)
None
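The parameter counts in the summary can be checked by hand: a dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and dropout layers add none. The sketch below infers the number of input features from the first layer's 336 parameters; `dense_params` is an illustrative helper, not part of the notebook.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# 336 = 16 * (n_features + 1), so the first layer sees 20 input features
n_features = 336 // 16 - 1

layer_sizes = [n_features, 16, 8, 1]
counts = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
trainable = sum(counts)

print(n_features, counts, trainable)
```

The reported total of 964 is these trainable parameters plus the 483 optimizer parameters listed in the summary.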
# Q2e Step 10: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_6, title = 'Training vs. Validation Loss (NN 6)')
p_va = nn6.predict(X_val_best, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.2)
sum_va = confusion_summary(y_va, p_va, 0.2)
print("\nConfusion Matrix Summary (0.2 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.2 threshold) (NN 6)')
plt.show()
Confusion Matrix Summary (0.2 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.2 | 2260 | 4153 | 457 | 1836 | 0.800698 | 0.306562 |
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
y_true = y_va, y_prob = p_va,
w_fn = W_FN, w_fp = W_FP,
plot = True,
return_curve = True,
plot_roc = True,
return_roc_curve = True,
model_name = 'NN 6'
)
threshold = 0.2
y_val_probs = nn6.predict(X_val_best, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.5820 | Optimal threshold: 0.3263 | Optimal Misclassification Cost: 0.5054 | Validation AUC: 0.618
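`choose_threshold` is defined earlier in the notebook, so its internals are not visible here. A common way to obtain a cost-optimal threshold is a simple grid sweep: compute the weighted misclassification cost at each candidate threshold and keep the minimum. The sketch below shows that idea under the assumed cost form `(W_FN * FN + W_FP * FP) / N`; `sweep_threshold` and the synthetic data are hypothetical.

```python
import numpy as np

def sweep_threshold(y_true, y_prob, w_fn=2.0, w_fp=1.0, grid=None):
    """Return (best threshold, best cost, cost at every grid point)."""
    y_true = np.asarray(y_true).astype(int).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    if grid is None:
        grid = np.linspace(0.0, 1.0, 101)
    costs = []
    for t in grid:
        pred = (y_prob >= t).astype(int)
        fn = np.sum((y_true == 1) & (pred == 0))   # missed acute cases
        fp = np.sum((y_true == 0) & (pred == 1))   # false alarms
        costs.append((w_fn * fn + w_fp * fp) / len(y_true))
    costs = np.array(costs)
    best = int(np.argmin(costs))
    return float(grid[best]), float(costs[best]), costs

# Synthetic, weakly informative scores for illustration only
rng = np.random.default_rng(0)
y_demo = rng.binomial(1, 0.26, size=5000)
p_demo = np.clip(0.26 + 0.1 * (y_demo - 0.26) + rng.normal(0, 0.1, size=5000), 0, 1)

thr_demo, cost_demo, costs_demo = sweep_threshold(y_demo, p_demo)
print(f"best threshold: {thr_demo:.2f}, cost: {cost_demo:.4f}, cost at 0.2: {costs_demo[20]:.4f}")
```

Because the threshold is tuned on the same validation data used to report the cost, the optimal cost is an optimistic estimate; a held-out test set gives a fairer figure.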
# Plot decile chart
plot_decile_chart(nn6, X_val_best, Y_val, model_name = 'NN 6')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
# Create a new row with the metrics for NN 6
nn6_metrics = {
'Model Name': 'NN 6',
'Threshold': sum_va['Threshold'].iloc[0],
'Validation Loss': min(nn_hist_6.history['val_loss']),
'Misclassification Cost': cost_sens_error,
'Recall': sum_va['Recall'].iloc[0],
'Precision': sum_va['Precision'].iloc[0],
'ROC_AUC': auc,
'Optimal Threshold': thr_star,
'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn6_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
| 1 | NN 2 | 0.4 | 0.673952 | 0.624971 | 0.835587 | 0.290171 | 0.598727 | 0.662289 | 0.508959 |
| 2 | NN 3 | 0.2 | 0.560748 | 0.597978 | 0.822939 | 0.300430 | 0.614514 | 0.382513 | 0.506432 |
| 3 | NN 4 | 0.2 | 0.560996 | 0.599127 | 0.813781 | 0.299615 | 0.613749 | 0.287431 | 0.496784 |
| 4 | NN 5 | 0.2 | 0.560492 | 0.582012 | 0.800698 | 0.306562 | 0.618289 | 0.326269 | 0.505399 |
| 5 | NN 6 | 0.2 | 0.560492 | 0.582012 | 0.800698 | 0.306562 | 0.618289 | 0.326269 | 0.505399 |
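The ROC_AUC column sits between roughly 0.60 and 0.62 for all six networks. One way to read that figure: AUC is the probability that a randomly chosen acute patient receives a higher score than a randomly chosen non-acute patient, with ties counted as half. The brute-force sketch below computes it from that definition; `pairwise_auc` is a hypothetical helper and the toy data are not from the notebook.

```python
import numpy as np

def pairwise_auc(y_true, y_prob):
    """AUC as the share of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true).astype(int).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    pos = y_prob[y_true == 1]
    neg = y_prob[y_true == 0]
    # Compare every positive score with every negative score
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / (len(pos) * len(neg))

toy_auc = pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(toy_auc)
```

The pairwise comparison needs memory proportional to the number of positives times the number of negatives, so it is meant for small illustrations rather than as a replacement for `roc_auc_score`.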
Q2e Step 11 - Regularisation Tuning (Iteration 7)
# Q2e Step 11: Model construction
dropout_rates = [0.1, 0.3]
l1_lambdas = [0.0, 1e-4, 1e-5]
l2_lambdas = [0.0, 1e-4, 1e-5]
threshold = 0.2
results = []
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
# Loop through regularisation configurations
for dr in dropout_rates:
    for l1 in l1_lambdas:
        for l2 in l2_lambdas:
            tf.keras.backend.clear_session()
            keras_deterministic(seed_value = 0)
            model = tf.keras.Sequential([
                layers.Input(shape = (X_training_best.shape[1],)),
                layers.Dense(16, activation = 'relu',
                             kernel_regularizer = regularizers.l1_l2(l1 = l1, l2 = l2)),
                layers.Dropout(dr),
                layers.Dense(8, activation = 'tanh',
                             kernel_regularizer = regularizers.l1_l2(l1 = l1, l2 = l2)),
                layers.Dropout(dr),
                layers.Dense(1, activation = 'sigmoid')
            ])
            model.compile(
                optimizer = tf.keras.optimizers.SGD(learning_rate = 0.001, momentum = 0.9),
                loss = 'binary_crossentropy',
                metrics = [
                    tf.keras.metrics.Precision(name = 'precision'),
                    tf.keras.metrics.Recall(name = 'recall'),
                    tf.keras.metrics.AUC(name = 'auc')
                ])
            early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
            history = model.fit(
                X_training_best, Y_training,
                validation_data = (X_val_best, Y_val),
                epochs = 50,
                batch_size = 32,
                callbacks = [early_stopping],
                # class_weight = class_weight_dict,
                verbose = 0
            )
            # Compute metrics
            val_loss = min(history.history['val_loss'])
            y_val_probs = model.predict(X_val_best).ravel()
            auc = roc_auc_score(Y_val, y_val_probs)
            y_val_preds = (y_val_probs >= threshold).astype(int)
            cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
            precision = precision_score(Y_val, y_val_preds, zero_division = 0)
            recall = recall_score(Y_val, y_val_preds, zero_division = 0)
            results.append({
                'Dropout Rate': dr,
                'L1 Lambda': l1,
                'L2 Lambda': l2,
                'Val Loss': val_loss,
                'Cost-Sensitive Error': cost_sens_error,
                f'Precision@{threshold}': precision,
                f'Recall@{threshold}': recall,
                'ROC_AUC': auc
            })
            print(f"Dropout {dr}, L1 {l1}, L2 {l2} → Val Loss: {val_loss:.4f}, Cost-Sensitive Error: {cost_sens_error:.4f}, Precision@{threshold}: {precision:.4f}, Recall@{threshold}: {recall:.4f}, ROC_AUC: {auc:.4f}")
# Collect the results (in grid order) and highlight the best value in each column
results_nn7 = pd.DataFrame(results)
def highlight_best(df):
    styles = pd.DataFrame('', index = df.index, columns = df.columns)
    # Row 9 (dropout 0.3, no L1 or L2 penalty) is the configuration used for NN 7
    styles.loc[9, :] = 'background-color: lightblue; font-weight: bold;'
    # Higher is better for these metrics
    for col in [f'Recall@{threshold}', f'Precision@{threshold}', 'ROC_AUC']:
        styles.loc[df[col] == df[col].max(), col] = ('background-color: lightgreen; font-weight: bold;')
    # Lower is better for these metrics
    for col in ['Val Loss', 'Cost-Sensitive Error']:
        styles.loc[df[col] == df[col].min(), col] = ('background-color: lightgreen; font-weight: bold;')
    return styles
# Apply style to results
results_nn7.style.apply(highlight_best, axis = None).format({
'Val Loss': "{:.4f}",
'Cost-Sensitive Error': "{:.4f}",
f'Precision@{threshold}': "{:.4f}",
f'Recall@{threshold}': "{:.4f}",
'ROC_AUC': "{:.4f}"
})
| | Dropout Rate | L1 Lambda | L2 Lambda | Val Loss | Cost-Sensitive Error | Precision@0.2 | Recall@0.2 | ROC_AUC |
|---|---|---|---|---|---|---|---|---|
| 0 | 0.100000 | 0.000000 | 0.000000 | 0.5600 | 0.5747 | 0.3096 | 0.7894 | 0.6199 |
| 1 | 0.100000 | 0.000000 | 0.000100 | 0.5623 | 0.5760 | 0.3091 | 0.7950 | 0.6195 |
| 2 | 0.100000 | 0.000000 | 0.000010 | 0.5603 | 0.5748 | 0.3095 | 0.7902 | 0.6198 |
| 3 | 0.100000 | 0.000100 | 0.000000 | 0.5677 | 0.5782 | 0.3081 | 0.7955 | 0.6184 |
| 4 | 0.100000 | 0.000100 | 0.000100 | 0.5688 | 0.5807 | 0.3071 | 0.8007 | 0.6176 |
| 5 | 0.100000 | 0.000100 | 0.000010 | 0.5678 | 0.5788 | 0.3079 | 0.7959 | 0.6183 |
| 6 | 0.100000 | 0.000010 | 0.000000 | 0.5610 | 0.5743 | 0.3098 | 0.7907 | 0.6197 |
| 7 | 0.100000 | 0.000010 | 0.000100 | 0.5630 | 0.5759 | 0.3092 | 0.7959 | 0.6193 |
| 8 | 0.100000 | 0.000010 | 0.000010 | 0.5612 | 0.5739 | 0.3100 | 0.7915 | 0.6196 |
| 9 | 0.300000 | 0.000000 | 0.000000 | 0.5605 | 0.5820 | 0.3066 | 0.8007 | 0.6183 |
| 10 | 0.300000 | 0.000000 | 0.000100 | 0.5624 | 0.5845 | 0.3058 | 0.8120 | 0.6180 |
| 11 | 0.300000 | 0.000000 | 0.000010 | 0.5607 | 0.5821 | 0.3065 | 0.8016 | 0.6182 |
| 12 | 0.300000 | 0.000100 | 0.000000 | 0.5668 | 0.5851 | 0.3056 | 0.8129 | 0.6173 |
| 13 | 0.300000 | 0.000100 | 0.000100 | 0.5671 | 0.5907 | 0.3035 | 0.8225 | 0.6168 |
| 14 | 0.300000 | 0.000100 | 0.000010 | 0.5668 | 0.5860 | 0.3052 | 0.8147 | 0.6172 |
| 15 | 0.300000 | 0.000010 | 0.000000 | 0.5614 | 0.5820 | 0.3066 | 0.8029 | 0.6181 |
| 16 | 0.300000 | 0.000010 | 0.000100 | 0.5631 | 0.5844 | 0.3059 | 0.8133 | 0.6179 |
| 17 | 0.300000 | 0.000010 | 0.000010 | 0.5616 | 0.5820 | 0.3067 | 0.8038 | 0.6181 |
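The L1 and L2 lambdas in this grid scale a penalty on each regularised layer's kernel: the sum of absolute weights times `l1` plus the sum of squared weights times `l2`. A small NumPy sketch of that penalty follows; `l1_l2_penalty` is an illustrative function and the weight matrix is made up.

```python
import numpy as np

def l1_l2_penalty(weights, l1=0.0, l2=0.0):
    """Combined L1 and L2 penalty for one weight matrix."""
    w = np.asarray(weights, dtype=float)
    return l1 * np.abs(w).sum() + l2 * np.square(w).sum()

# Made-up 2 x 2 kernel
w_demo = np.array([[0.5, -1.0],
                   [2.0,  0.0]])

penalty = l1_l2_penalty(w_demo, l1=1e-4, l2=1e-5)
print(f"{penalty:.6e}")
```

If the penalty is included in the logged loss, as is usual in Keras, the Val Loss values of regularised rows carry this extra term and are not strictly comparable with the unregularised rows; the threshold-based metrics and ROC_AUC are unaffected by that.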
tf.keras.backend.clear_session()
keras_deterministic(seed_value = 0)
# Refine the neural network architecture
nn7 = keras.Sequential([
layers.Input(shape = (X_training_best.shape[1],)),
layers.Dense(16, activation = 'relu'),
layers.Dropout(0.3),
layers.Dense(8, activation = 'tanh'),
layers.Dropout(0.3),
layers.Dense(1, activation = 'sigmoid')
])
# Compile the model
nn7.compile(
optimizer = keras.optimizers.SGD(learning_rate = 0.001, momentum = 0.9),
loss = 'binary_crossentropy',
metrics = [
tf.keras.metrics.Precision(name = 'precision'),
tf.keras.metrics.Recall(name = 'recall'),
tf.keras.metrics.AUC(name = 'auc')
]
)
# Fit the model with early stopping
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)
nn_hist_7 = nn7.fit(
X_training_best, Y_training,
validation_data = (X_val_best, Y_val),
epochs = 50,
batch_size = 32,
callbacks = [early_stopping],
# class_weight = class_weight_dict,
verbose = 1
)
# Check model parameters
display(nn7.summary())
Epoch 1/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - auc: 0.4991 - loss: 0.6511 - precision: 0.2511 - recall: 0.1766 - val_auc: 0.5571 - val_loss: 0.5730 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 889us/step - auc: 0.5325 - loss: 0.5706 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_auc: 0.5883 - val_loss: 0.5696 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 882us/step - auc: 0.5511 - loss: 0.5667 - precision: 0.6399 - recall: 4.5543e-04 - val_auc: 0.5981 - val_loss: 0.5683 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 896us/step - auc: 0.5646 - loss: 0.5641 - precision: 0.1334 - recall: 4.4333e-05 - val_auc: 0.6037 - val_loss: 0.5671 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 5/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 882us/step - auc: 0.5687 - loss: 0.5637 - precision: 0.3660 - recall: 4.8436e-04 - val_auc: 0.6063 - val_loss: 0.5662 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 6/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 893us/step - auc: 0.5762 - loss: 0.5623 - precision: 0.6731 - recall: 5.4479e-04 - val_auc: 0.6076 - val_loss: 0.5655 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 7/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 906us/step - auc: 0.5780 - loss: 0.5619 - precision: 0.6248 - recall: 0.0010 - val_auc: 0.6094 - val_loss: 0.5648 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 8/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 877us/step - auc: 0.5822 - loss: 0.5609 - precision: 0.4420 - recall: 7.3995e-04 - val_auc: 0.6099 - val_loss: 0.5644 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 9/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 879us/step - auc: 0.5898 - loss: 0.5593 - precision: 0.4178 - recall: 0.0013 - val_auc: 0.6109 - val_loss: 0.5640 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 10/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.5860 - loss: 0.5606 - precision: 0.6126 - recall: 0.0013 - val_auc: 0.6112 - val_loss: 0.5638 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 11/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 853us/step - auc: 0.5900 - loss: 0.5596 - precision: 0.6621 - recall: 0.0011 - val_auc: 0.6114 - val_loss: 0.5635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 12/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 896us/step - auc: 0.5923 - loss: 0.5591 - precision: 0.7082 - recall: 0.0016 - val_auc: 0.6117 - val_loss: 0.5633 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 13/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 876us/step - auc: 0.5921 - loss: 0.5591 - precision: 0.5131 - recall: 0.0012 - val_auc: 0.6123 - val_loss: 0.5630 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 14/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 898us/step - auc: 0.5968 - loss: 0.5587 - precision: 0.3049 - recall: 8.6076e-04 - val_auc: 0.6130 - val_loss: 0.5628 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 15/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 903us/step - auc: 0.5980 - loss: 0.5583 - precision: 0.2970 - recall: 7.3695e-04 - val_auc: 0.6128 - val_loss: 0.5627 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 16/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 918us/step - auc: 0.5958 - loss: 0.5583 - precision: 0.4246 - recall: 8.2051e-04 - val_auc: 0.6127 - val_loss: 0.5627 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 17/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 892us/step - auc: 0.5974 - loss: 0.5581 - precision: 0.4016 - recall: 0.0015 - val_auc: 0.6132 - val_loss: 0.5625 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 18/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 869us/step - auc: 0.6010 - loss: 0.5575 - precision: 0.6782 - recall: 0.0016 - val_auc: 0.6134 - val_loss: 0.5623 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 19/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6014 - loss: 0.5572 - precision: 0.5393 - recall: 0.0022 - val_auc: 0.6145 - val_loss: 0.5622 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 20/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 872us/step - auc: 0.6031 - loss: 0.5571 - precision: 0.5741 - recall: 0.0017 - val_auc: 0.6152 - val_loss: 0.5618 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 21/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 890us/step - auc: 0.6042 - loss: 0.5569 - precision: 0.5381 - recall: 0.0012 - val_auc: 0.6154 - val_loss: 0.5618 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 22/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 876us/step - auc: 0.6062 - loss: 0.5562 - precision: 0.5181 - recall: 0.0020 - val_auc: 0.6156 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 23/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 889us/step - auc: 0.6050 - loss: 0.5568 - precision: 0.4053 - recall: 0.0013 - val_auc: 0.6153 - val_loss: 0.5616 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 24/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 893us/step - auc: 0.6064 - loss: 0.5563 - precision: 0.4677 - recall: 0.0019 - val_auc: 0.6153 - val_loss: 0.5615 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 25/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.6103 - loss: 0.5555 - precision: 0.3198 - recall: 0.0011 - val_auc: 0.6156 - val_loss: 0.5614 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 26/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 873us/step - auc: 0.6104 - loss: 0.5553 - precision: 0.4232 - recall: 0.0016 - val_auc: 0.6162 - val_loss: 0.5613 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 27/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 892us/step - auc: 0.6091 - loss: 0.5555 - precision: 0.4537 - recall: 0.0013 - val_auc: 0.6163 - val_loss: 0.5612 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 28/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 890us/step - auc: 0.6086 - loss: 0.5557 - precision: 0.4519 - recall: 0.0016 - val_auc: 0.6167 - val_loss: 0.5612 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 29/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 893us/step - auc: 0.6104 - loss: 0.5554 - precision: 0.3590 - recall: 0.0018 - val_auc: 0.6168 - val_loss: 0.5611 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 30/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 889us/step - auc: 0.6119 - loss: 0.5552 - precision: 0.4485 - recall: 0.0013 - val_auc: 0.6175 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 31/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6140 - loss: 0.5547 - precision: 0.6222 - recall: 0.0019 - val_auc: 0.6174 - val_loss: 0.5609 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 32/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.6108 - loss: 0.5552 - precision: 0.6002 - recall: 0.0015 - val_auc: 0.6178 - val_loss: 0.5610 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 33/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 876us/step - auc: 0.6135 - loss: 0.5546 - precision: 0.3840 - recall: 0.0014 - val_auc: 0.6174 - val_loss: 0.5609 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 34/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 885us/step - auc: 0.6137 - loss: 0.5549 - precision: 0.4721 - recall: 0.0013 - val_auc: 0.6179 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 35/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 883us/step - auc: 0.6133 - loss: 0.5546 - precision: 0.5580 - recall: 0.0022 - val_auc: 0.6177 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 36/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 878us/step - auc: 0.6135 - loss: 0.5544 - precision: 0.3738 - recall: 0.0013 - val_auc: 0.6178 - val_loss: 0.5608 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 37/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 892us/step - auc: 0.6144 - loss: 0.5542 - precision: 0.5167 - recall: 0.0016 - val_auc: 0.6183 - val_loss: 0.5607 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 38/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 862us/step - auc: 0.6134 - loss: 0.5547 - precision: 0.4703 - recall: 0.0021 - val_auc: 0.6180 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 39/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 856us/step - auc: 0.6177 - loss: 0.5533 - precision: 0.5408 - recall: 0.0020 - val_auc: 0.6181 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 40/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.6208 - loss: 0.5526 - precision: 0.4139 - recall: 0.0015 - val_auc: 0.6185 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 41/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 882us/step - auc: 0.6193 - loss: 0.5528 - precision: 0.5167 - recall: 0.0020 - val_auc: 0.6187 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 42/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 875us/step - auc: 0.6153 - loss: 0.5541 - precision: 0.3757 - recall: 0.0019 - val_auc: 0.6185 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 43/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 866us/step - auc: 0.6152 - loss: 0.5543 - precision: 0.3453 - recall: 0.0011 - val_auc: 0.6180 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 44/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 886us/step - auc: 0.6161 - loss: 0.5539 - precision: 0.5181 - recall: 0.0018 - val_auc: 0.6187 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 45/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 884us/step - auc: 0.6152 - loss: 0.5539 - precision: 0.5382 - recall: 0.0021 - val_auc: 0.6180 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 46/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 863us/step - auc: 0.6190 - loss: 0.5529 - precision: 0.4149 - recall: 0.0022 - val_auc: 0.6177 - val_loss: 0.5606 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 47/50 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 880us/step - auc: 0.6171 - loss: 0.5536 - precision: 0.3479 - recall: 9.3537e-04 - val_auc: 0.6180 - val_loss: 0.5605 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 16)             │           336 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 16)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 8)              │           136 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 8)              │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │             9 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 964 (3.77 KB)
Trainable params: 481 (1.88 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 483 (1.89 KB)
None
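The trainable parameter counts in the summary can be checked by hand: a Dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and Dropout layers add none. The sketch below is a minimal check of that arithmetic. The input width of 20 is not stated in this section; it is inferred from the 336 parameters of the first layer, so treat it as an assumption. `dense_params` is a hypothetical helper name.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Number of weights (n_in * n_out) plus one bias per unit."""
    return n_in * n_out + n_out

# Assumption: 20 input features, inferred from (336 - 16) / 16 = 20.
n_features = 20
layer_sizes = [n_features, 16, 8, 1]

# One count per Dense layer: input -> 16, 16 -> 8, 8 -> 1.
counts = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total_trainable = sum(counts)

print(counts, total_trainable)
```

The three counts match the `Param #` column, and their sum matches the 481 trainable parameters. The remaining 483 "Optimizer params" bring the total to 964; these most likely correspond to the SGD momentum buffers plus a small number of scalar optimizer variables, although the summary itself does not break them down.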
# Q2e Step 11: Analysis of model behaviours and drivers
# Plot training and validation loss
plot_training_val_loss(nn_hist_7, title = 'Training vs. Validation Loss (NN 7)')
p_va = nn7.predict(X_val_best, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
# Show confusion matrix
cm_va = confusion_matrix_df(y_va, p_va, 0.2)
sum_va = confusion_summary(y_va, p_va, 0.2)
print("\nConfusion Matrix Summary (0.2 threshold):\n")
display(sum_va)
plot_confusion_matrix(cm_va, title = 'Validation Confusion Matrix (0.2 threshold) (NN 7)')
plt.show()
Confusion Matrix Summary (0.2 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 8706 | 0.2 | 2260 | 4153 | 457 | 1836 | 0.800698 | 0.306562 |
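The recall and precision in the summary follow directly from the four confusion counts. The sketch below is an assumption about what a helper such as `confusion_summary` computes (its definition is not shown in this section); `confusion_counts` and `recall_precision` are hypothetical names, and the toy labels and scores are made up purely to exercise the counting step.

```python
def confusion_counts(y_true, y_prob, threshold):
    """Return (TN, FP, FN, TP), flagging every probability >= threshold."""
    tn = fp = fn = tp = 0
    for y, p in zip(y_true, y_prob):
        flagged = p >= threshold
        if flagged and y == 1:
            tp += 1
        elif flagged:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    return tn, fp, fn, tp


def recall_precision(fn, fp, tp):
    """Recall = TP / (TP + FN); precision = TP / (TP + FP)."""
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    return recall, precision


# Toy check of the counting step (made-up data, not from the assignment).
toy_counts = confusion_counts([0, 0, 1, 1], [0.10, 0.30, 0.15, 0.40], 0.2)

# Counts reported above for NN 7 on the validation data at a 0.2 threshold.
recall, precision = recall_precision(fn=457, fp=4153, tp=1836)
print(f"recall={recall:.6f}, precision={precision:.6f}")
```

Using the reported counts (FN = 457, FP = 4153, TP = 1836) reproduces the recall of 0.800698 and precision of 0.306562 shown in the table.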
# Show threshold tuning with optimal threshold
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
y_true = y_va, y_prob = p_va,
w_fn = W_FN, w_fp = W_FP,
plot = True,
return_curve = True,
plot_roc = True,
return_roc_curve = True,
model_name = 'NN 7'
)
threshold = 0.2
y_val_probs = nn7.predict(X_val_best, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
print(f"Misclassification Cost: {cost_sens_error:.4f} | Optimal threshold: {thr_star:.4f} | Optimal Misclassification Cost: {val_cost:.4f} | Validation AUC: {auc:.3f}")
Misclassification Cost: 0.5820 | Optimal threshold: 0.3263 | Optimal Misclassification Cost: 0.5054 | Validation AUC: 0.618
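The "optimal threshold" reported above is the threshold that minimises the weighted misclassification cost on the validation data. The sketch below shows one way such a search could work, assuming a simple grid of candidate thresholds; it is not the `choose_threshold` helper used in this notebook (whose definition is not shown in this section). The names `weighted_cost` and `sweep_thresholds` are hypothetical, and the data are synthetic.

```python
import numpy as np


def weighted_cost(y_true, y_prob, threshold, w_fn=2.0, w_fp=1.0):
    """Average cost per patient: (w_fn * FN + w_fp * FP) / N."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    fn = np.sum((y_true == 1) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    return float(w_fn * fn + w_fp * fp) / len(y_true)


def sweep_thresholds(y_true, y_prob, w_fn=2.0, w_fp=1.0, n_grid=101):
    """Evaluate the cost on an evenly spaced grid and return the minimiser."""
    grid = np.linspace(0.0, 1.0, n_grid)
    costs = np.array([weighted_cost(y_true, y_prob, t, w_fn, w_fp) for t in grid])
    best = int(np.argmin(costs))
    return float(grid[best]), float(costs[best]), grid, costs


# Synthetic scores that are only weakly related to the label (illustrative only).
rng = np.random.default_rng(0)
y_true = rng.binomial(1, 0.26, size=5000)
y_prob = np.clip(0.25 + 0.08 * y_true + rng.normal(0, 0.08, size=5000), 0, 1)

thr_best, cost_best, grid, costs = sweep_thresholds(y_true, y_prob)
cost_at_02 = weighted_cost(y_true, y_prob, 0.2)
print(f"best threshold={thr_best:.2f}, cost={cost_best:.4f}, cost at 0.20={cost_at_02:.4f}")
```

Because 0.2 is one of the grid points, the minimum cost found by the sweep can never exceed the cost at the fixed 0.2 threshold. This mirrors the gap between the 0.5820 and 0.5054 figures printed above.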
# Plot decile chart
plot_decile_chart(nn7, X_val_best, Y_val, model_name = 'NN 7')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
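A decile chart ranks patients by predicted probability, splits them into ten equal-sized groups, and compares the observed acute rate across groups. The sketch below is an assumption about the kind of table that could sit behind a chart like `plot_decile_chart` (defined elsewhere in the notebook and not shown in this section); `decile_table` is a hypothetical name and the data are synthetic.

```python
import numpy as np


def decile_table(y_true, y_prob, n_bins=10):
    """Return one (mean predicted, observed rate, count) row per bin,
    ordered from the highest-risk bin to the lowest-risk bin."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    order = np.argsort(-y_prob, kind="stable")  # indices, highest score first
    rows = []
    for idx in np.array_split(order, n_bins):
        rows.append((float(y_prob[idx].mean()), float(y_true[idx].mean()), int(len(idx))))
    return rows


# Synthetic, well-calibrated scores (illustrative only).
rng = np.random.default_rng(1)
y_prob = rng.uniform(0.10, 0.46, size=2000)
y_true = rng.binomial(1, y_prob)

table = decile_table(y_true, y_prob)
for i, (pred, obs, n) in enumerate(table, start=1):
    print(f"decile {i:2d}: mean predicted={pred:.3f}, observed rate={obs:.3f}, n={n}")
```

For a model that ranks patients well, the observed rate should fall steadily from the first decile to the last; a flat profile indicates little discriminating power.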
# Create a new row with the metrics for NN 7
nn7_metrics = {
'Model Name': 'NN 7',
'Threshold': sum_va['Threshold'].iloc[0],
'Validation Loss': min(nn_hist_7.history['val_loss']),
'Misclassification Cost': cost_sens_error,
'Recall': sum_va['Recall'].iloc[0],
'Precision': sum_va['Precision'].iloc[0],
'ROC_AUC': auc,
'Optimal Threshold': thr_star,
'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df = pd.concat([model_metrics_df, pd.DataFrame([nn7_metrics])], ignore_index = True)
# Display the combined results
display(model_metrics_df)
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.4 | 0.567409 | 0.514473 | 0.074139 | 0.421836 | 0.596707 | 0.402043 | 0.513439 |
| 1 | NN 2 | 0.4 | 0.673952 | 0.624971 | 0.835587 | 0.290171 | 0.598727 | 0.662289 | 0.508959 |
| 2 | NN 3 | 0.2 | 0.560748 | 0.597978 | 0.822939 | 0.300430 | 0.614514 | 0.382513 | 0.506432 |
| 3 | NN 4 | 0.2 | 0.560996 | 0.599127 | 0.813781 | 0.299615 | 0.613749 | 0.287431 | 0.496784 |
| 4 | NN 5 | 0.2 | 0.560492 | 0.582012 | 0.800698 | 0.306562 | 0.618289 | 0.326269 | 0.505399 |
| 5 | NN 6 | 0.2 | 0.560492 | 0.582012 | 0.800698 | 0.306562 | 0.618289 | 0.326269 | 0.505399 |
| 6 | NN 7 | 0.2 | 0.560492 | 0.582012 | 0.800698 | 0.306562 | 0.618289 | 0.326269 | 0.505399 |
Savepoint¶
nn1.save('/content/gdrive/My Drive/DSA Assignment Data/nn1.keras')
nn2.save('/content/gdrive/My Drive/DSA Assignment Data/nn2.keras')
nn3.save('/content/gdrive/My Drive/DSA Assignment Data/nn3.keras')
nn4.save('/content/gdrive/My Drive/DSA Assignment Data/nn4.keras')
nn5.save('/content/gdrive/My Drive/DSA Assignment Data/nn5.keras')
nn6.save('/content/gdrive/My Drive/DSA Assignment Data/nn6.keras')
nn7.save('/content/gdrive/My Drive/DSA Assignment Data/nn7.keras')
model_metrics_df.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/model_metrics_df.pkl')
Loadpoint¶
nn1 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn1.keras')
nn2 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn2.keras')
nn3 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn3.keras')
nn4 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn4.keras')
nn5 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn5.keras')
nn6 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn6.keras')
nn7 = load_model('/content/gdrive/My Drive/DSA Assignment Data/nn7.keras')
model_metrics_df = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/model_metrics_df.pkl')
X_training_best = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_training_best.pkl')
X_val_best = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_val_best.pkl')
X_test_best = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/X_test_best.pkl')
Q2e Step 12 - Final model (Iteration 5)¶
Comparison of All Models (Iterations 1 to 5)¶
# Q2e Step 12: Model construction
# Define models and labels
models = [nn1, nn2, nn3, nn4, nn5]
model_names = ['NN 1', 'NN 2', 'NN 3', 'NN 4', 'NN 5']
# Create an empty list to store confusion summaries
all_confusion_summaries = []
for model, name in zip(models, model_names):
    if name in ['NN 1', 'NN 2']:
        X_val_input = X_val
    else:
        X_val_input = X_val_best
    if name in ['NN 1', 'NN 2']:
        threshold = 0.4
    else:
        threshold = 0.2
    p_va = model.predict(X_val_input, verbose = 0).ravel()
    y_va = np.asarray(Y_val).astype(int).ravel()
    sum_va = confusion_summary(y_va, p_va, threshold)
    # Add model name and append to list
    sum_va['Model Name'] = name
    all_confusion_summaries.append(sum_va)
# Concatenate all summaries into a single DataFrame
combined_confusion_summaries = pd.concat(all_confusion_summaries, ignore_index = True)
# Reorder columns to have Model Name first
cols = ['Model Name'] + [col for col in combined_confusion_summaries.columns if col != 'Model Name']
combined_confusion_summaries = combined_confusion_summaries[cols]
# Display model_metrics_df where model name in model_names
print("Combined Model Metrics:")
display(model_metrics_df[model_metrics_df['Model Name'].isin(model_names)].style.format({
'Threshold': "{:.2f}",
'Validation Loss': "{:.4f}",
'Misclassification Cost': "{:.4f}",
'Recall': "{:.4f}",
'Precision': "{:.4f}",
'ROC_AUC': "{:.4f}",
'Optimal Threshold': "{:.2f}",
'Optimal Misclassification Cost': "{:.4f}"
}))
# Display the combined confusion summaries
print("\nCombined Confusion Matrix Summaries:")
display(combined_confusion_summaries.style.format({
'Threshold': "{:.2f}",
'Recall': "{:.4f}",
'Precision': "{:.4f}"
}))
# Plot the metrics
plt.figure(figsize = (8, 5))
plt.plot(model_metrics_df['Model Name'], model_metrics_df['Misclassification Cost'], marker = 'o', label = 'Misclassification Cost')
plt.plot(model_metrics_df['Model Name'], model_metrics_df['Recall'], marker = 'o', label = 'Recall')
plt.plot(model_metrics_df['Model Name'], model_metrics_df['Precision'], marker = 'o', label = 'Precision')
plt.plot(model_metrics_df['Model Name'], model_metrics_df['ROC_AUC'], marker = 'o', label = 'ROC_AUC')
plt.plot(model_metrics_df['Model Name'], model_metrics_df['Optimal Misclassification Cost'], marker = 'o', label = 'Optimal Misclassification Cost')
plt.title('Model Performance')
plt.xlabel('Model')
plt.ylabel('Score')
plt.ylim(0, 1)
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
# Show validation vs training losses
plot_training_val_loss(nn_hist_1, title = 'Training vs. Validation Loss (NN 1)')
plot_training_val_loss(nn_hist_2, title = 'Training vs. Validation Loss (NN 2)')
plot_training_val_loss(nn_hist_3, title = 'Training vs. Validation Loss (NN 3)')
plot_training_val_loss(nn_hist_4, title = 'Training vs. Validation Loss (NN 4)')
plot_training_val_loss(nn_hist_5, title = 'Training vs. Validation Loss (NN 5)')
Combined Model Metrics:
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.40 | 0.5674 | 0.5145 | 0.0741 | 0.4218 | 0.5967 | 0.40 | 0.5134 |
| 1 | NN 2 | 0.40 | 0.6740 | 0.6250 | 0.8356 | 0.2902 | 0.5987 | 0.66 | 0.5090 |
| 2 | NN 3 | 0.20 | 0.5607 | 0.5980 | 0.8229 | 0.3004 | 0.6145 | 0.38 | 0.5064 |
| 3 | NN 4 | 0.20 | 0.5610 | 0.5991 | 0.8138 | 0.2996 | 0.6137 | 0.29 | 0.4968 |
| 4 | NN 5 | 0.20 | 0.5605 | 0.5820 | 0.8007 | 0.3066 | 0.6183 | 0.33 | 0.5054 |
Combined Confusion Matrix Summaries:
| | Model Name | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 8706 | 0.40 | 6180 | 233 | 2123 | 170 | 0.0741 | 0.4218 |
| 1 | NN 2 | 8706 | 0.40 | 1726 | 4687 | 377 | 1916 | 0.8356 | 0.2902 |
| 2 | NN 3 | 8706 | 0.20 | 2019 | 4394 | 406 | 1887 | 0.8229 | 0.3004 |
| 3 | NN 4 | 8706 | 0.20 | 2051 | 4362 | 427 | 1866 | 0.8138 | 0.2996 |
| 4 | NN 5 | 8706 | 0.20 | 2260 | 4153 | 457 | 1836 | 0.8007 | 0.3066 |
# Show decile charts
plot_decile_chart(nn1, X_val, Y_val, model_name = 'NN 1')
plot_decile_chart(nn2, X_val, Y_val, model_name = 'NN 2')
plot_decile_chart(nn3, X_val_best, Y_val, model_name = 'NN 3')
plot_decile_chart(nn4, X_val_best, Y_val, model_name = 'NN 4')
plot_decile_chart(nn5, X_val_best, Y_val, model_name = 'NN 5')
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
273/273 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
Comparison of Final Model to First Model (Iterations 1 and 5)¶
# Q2e Step 12: Analysis of model behaviours and drivers
p_va = nn1.predict(X_val, verbose = 0).ravel()
y_va = np.asarray(Y_val).astype(int).ravel()
sum_va = confusion_summary(y_va, p_va, 0.2)
sum_va['Model Name'] = 'NN 1'
cols = ['Model Name'] + [col for col in sum_va.columns if col != 'Model Name']
sum_va = sum_va[cols]
confusion_summaries_nn1_nn5 = pd.concat([combined_confusion_summaries, sum_va], ignore_index = True)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_star, val_cost, curve_df, roc_df, auc = choose_threshold(
y_true = y_va, y_prob = p_va,
w_fn = W_FN, w_fp = W_FP,
plot = False,
return_curve = True,
plot_roc = False,
return_roc_curve = True,
model_name = 'NN 1'
)
threshold = 0.2
y_val_probs = nn1.predict(X_val, verbose = 0).ravel()
y_val_preds = (y_val_probs >= threshold).astype(int)
cost_sens_error = cost_sensitive_error(Y_val, y_val_preds, W_FN, W_FP)
nn1_metrics = {
'Model Name': 'NN 1',
'Threshold': sum_va['Threshold'].iloc[0],
'Validation Loss': min(nn_hist_1.history['val_loss']),
'Misclassification Cost': cost_sens_error,
'Recall': sum_va['Recall'].iloc[0],
'Precision': sum_va['Precision'].iloc[0],
'ROC_AUC': auc,
'Optimal Threshold': thr_star,
'Optimal Misclassification Cost': val_cost
}
# Append the new row to the combined results
model_metrics_df_nn1_nn5 = pd.concat([model_metrics_df, pd.DataFrame([nn1_metrics])], ignore_index = True)
# Display the rows in combined model metrics of NN 1 and NN 5
print("Combined Model Metrics of NN 1 and NN5:")
display(model_metrics_df_nn1_nn5[model_metrics_df_nn1_nn5['Model Name'].isin(['NN 1', 'NN 5'])].sort_values(by = ['Model Name', 'Threshold'], ascending = [True, False]).reset_index(drop = True).style.format({
'Threshold': "{:.2f}",
'Validation Loss': "{:.4f}",
'Misclassification Cost': "{:.4f}",
'Recall': "{:.4f}",
'Precision': "{:.4f}",
'ROC_AUC': "{:.4f}",
'Optimal Threshold': "{:.2f}",
'Optimal Misclassification Cost': "{:.4f}"
}))
# Display the rows in combined confusion summaries of NN 1 and NN 5
print("\nConfusion Matrix Summaries of NN 1 and NN 5:")
display(confusion_summaries_nn1_nn5[confusion_summaries_nn1_nn5['Model Name'].isin(['NN 1', 'NN 5'])].sort_values(by = ['Model Name', 'Threshold'], ascending = [True, False]).reset_index(drop = True).style.format({
'Threshold': "{:.2f}",
'Recall': "{:.4f}",
'Precision': "{:.4f}"
}))
Combined Model Metrics of NN 1 and NN5:
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 0.40 | 0.5674 | 0.5145 | 0.0741 | 0.4218 | 0.5967 | 0.40 | 0.5134 |
| 1 | NN 1 | 0.20 | 0.5674 | 0.5953 | 0.7867 | 0.3002 | 0.5967 | 0.40 | 0.5134 |
| 2 | NN 5 | 0.20 | 0.5605 | 0.5820 | 0.8007 | 0.3066 | 0.6183 | 0.33 | 0.5054 |
Confusion Matrix Summaries of NN 1 and NN 5:
| | Model Name | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 1 | 8706 | 0.40 | 6180 | 233 | 2123 | 170 | 0.0741 | 0.4218 |
| 1 | NN 1 | 8706 | 0.20 | 2208 | 4205 | 489 | 1804 | 0.7867 | 0.3002 |
| 2 | NN 5 | 8706 | 0.20 | 2260 | 4153 | 457 | 1836 | 0.8007 | 0.3066 |
Checks on Final Model¶
# Q2e Step 12: Analysis of model behaviours and drivers
# Get predicted probabilities
y_test_probs = nn5.predict(X_test_best).ravel()
# Get predicted class labels using threshold = 0.2
y_test_preds = (y_test_probs >= 0.2).astype(int)
# Plot AvE graph
plot_actual_vs_expected(Y_test, y_test_probs, n_bins = 10, model_name = 'NN 5')
# Check unique predicted classes
unique_classes, counts = np.unique(y_test_preds, return_counts = True)
for cls, count in zip(unique_classes, counts):
    print(f"Class {cls}: {count} instances")
# Check min and max of predicted probabilities
print("\nMin predicted probability (on test data):", y_test_probs.min())
print("Max predicted probability (on test data):", y_test_probs.max())
# Check if all values are between 0 and 1
if (y_test_probs >= 0).all() and (y_test_probs <= 1).all():
    print("\nAll predicted probabilities are within the range [0, 1].")
else:
    print("\nSome predicted probabilities are outside the valid range.")
305/305 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
Class 0: 3083 instances
Class 1: 6666 instances

Min predicted probability (on test data): 0.11427789
Max predicted probability (on test data): 0.45921242

All predicted probabilities are within the range [0, 1].
Summary¶
Q2f - Interpret performance¶
Q2f Step 1 - Performance of Chosen Model¶
Confusion Matrix¶
# Q2f Step 1
p_te = nn5.predict(X_test_best, verbose = 0).ravel()
y_te = np.asarray(Y_test).astype(int).ravel()
# Show confusion matrix
cm_te = confusion_matrix_df(y_te, p_te, 0.2)
sum_te = confusion_summary(y_te, p_te, 0.2)
print("\nConfusion Matrix Summary on Test Data (0.2 threshold):\n")
display(sum_te)
plot_confusion_matrix(cm_te, title = 'Confusion Matrix on Test Data (0.2 threshold) (NN 5)')
plt.show()
Confusion Matrix Summary on Test Data (0.2 threshold):
| | N | Threshold | TN | FP | FN | TP | Recall | Precision |
|---|---|---|---|---|---|---|---|---|
| 0 | 9749 | 0.2 | 2680 | 4574 | 403 | 2092 | 0.838477 | 0.313831 |
Misclassification Cost, Recall, Precision, and ROC_AUC¶
# Q2f Step 1
# Show misclassification cost, recall, precision, and ROC_AUC of test data
# Business weights (i.e. model-agnostic)
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
thr_out, test_cost, curve_df, roc_df, auc = plot_metrics_test(
y_true = y_te, y_prob = p_te,
w_fn = W_FN, w_fp = W_FP,
threshold = 0.20,
plot = True,
return_curve = True,
plot_roc = True,
return_roc_curve = True,
title_cost = 'Performance Metrics on Test Data (NN 5)',
model_name = 'NN 5'
)
print(f"Misclassification cost: {test_cost:.4f} | AUC: {auc:.4f}")
Misclassification cost: 0.5519 | AUC: 0.6500
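The printed misclassification cost can be reproduced from the confusion counts: each false negative is weighted by `W_FN`, each false positive by `W_FP`, and the total is divided by the number of patients. This is the same formula that appears in `metrics_at_threshold` in the benchmark step. The sketch below is a minimal check using the counts reported in this notebook; `cost_per_patient` is a hypothetical helper name.

```python
def cost_per_patient(fn, fp, n, w_fn=2.0, w_fp=1.0):
    """Weighted misclassification cost averaged over all n patients."""
    return (w_fn * fn + w_fp * fp) / n


# Test data at a 0.2 threshold: N = 9749, FP = 4574, FN = 403.
test_cost_check = cost_per_patient(fn=403, fp=4574, n=9749)

# Validation data at a 0.2 threshold: N = 8706, FP = 4153, FN = 457.
val_cost_check = cost_per_patient(fn=457, fp=4153, n=8706)

print(f"test cost={test_cost_check:.4f}, validation cost={val_cost_check:.4f}")
```

With `W_FN = 2.0` and `W_FP = 1.0`, the test counts give (2 × 403 + 4574) / 9749 ≈ 0.5519 and the validation counts give (2 × 457 + 4153) / 8706 ≈ 0.5820, in line with the values printed in the notebook.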
Summary¶
# Q2f Step 1
test_loss, test_precision, test_recall, test_auc = nn5.evaluate(X_test_best, Y_test, verbose = 0)
# Create a new row with the test metrics for NN 5
nn5_metrics_test = {
'Model Name': 'NN 5',
'Threshold': sum_te['Threshold'].iloc[0],
'Test Loss': test_loss,
'Misclassification Cost': test_cost,
'Recall': sum_te['Recall'].iloc[0],
'Precision': sum_te['Precision'].iloc[0],
'ROC_AUC': auc
}
# Display the validation and test results
print("Validation Metrics:")
display(model_metrics_df[model_metrics_df['Model Name'] == 'NN 5'].reset_index(drop = True).style.format({
'Threshold': "{:.2f}",
'Validation Loss': "{:.4f}",
'Misclassification Cost': "{:.4f}",
'Recall': "{:.4f}",
'Precision': "{:.4f}",
'ROC_AUC': "{:.4f}",
'Optimal Threshold': "{:.2f}",
'Optimal Misclassification Cost': "{:.4f}"
}))
print("Test Metrics:")
display(pd.DataFrame([nn5_metrics_test]).style.format({
'Threshold': "{:.2f}",
'Test Loss': "{:.4f}",
'Misclassification Cost': "{:.4f}",
'Precision': "{:.4f}",
'Recall': "{:.4f}",
'ROC_AUC': "{:.4f}"
}))
Validation Metrics:
| | Model Name | Threshold | Validation Loss | Misclassification Cost | Recall | Precision | ROC_AUC | Optimal Threshold | Optimal Misclassification Cost |
|---|---|---|---|---|---|---|---|---|---|
| 0 | NN 5 | 0.20 | 0.5605 | 0.5820 | 0.8007 | 0.3066 | 0.6183 | 0.33 | 0.5054 |
Test Metrics:
| | Model Name | Threshold | Test Loss | Misclassification Cost | Recall | Precision | ROC_AUC |
|---|---|---|---|---|---|---|---|
| 0 | NN 5 | 0.20 | 0.5443 | 0.5519 | 0.8385 | 0.3138 | 0.6500 |
Q2f Step 2 - Comparison to Benchmarks (Logistic Regression and Shallow Decision Tree)¶
# Q2f Step 2
# Relative cost of missing a true future acute case
W_FN = 2.0
# Relative cost of flagging a non-acute case as acute
W_FP = 1.0
def metrics_at_threshold(y_true, proba, thr, cost_fn=W_FN, cost_fp=W_FP):
    pred = (proba >= thr).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, pred, labels=[0,1]).ravel()
    # Threshold-based metrics
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    # Misclassification cost
    cost_total = cost_fn * fn + cost_fp * fp
    cost_per_sample = cost_total / len(y_true)
    # ROC_AUC
    roc_auc = roc_auc_score(y_true, proba)
    return {
        "threshold": thr,
        "ROC_AUC": roc_auc,
        "Recall": recall,
        "Precision": precision,
        "MisclassificationCost": float(cost_per_sample),
        "TP": int(tp), "FP": int(fp), "TN": int(tn), "FN": int(fn),
    }
def evaluate_model(pipe, name, X_val, y_val, X_test, y_test, threshold=0.2):
    proba_val = pipe.predict_proba(X_val)[:, 1]
    val_metrics = metrics_at_threshold(y_val, proba_val, threshold)
    proba_test = pipe.predict_proba(X_test)[:, 1]
    test_metrics = metrics_at_threshold(y_test, proba_test, threshold)
    # Tidy table
    row = {
        "Model Name": name,
        "Threshold": threshold,
        "Misclassification Cost (Val)": val_metrics["MisclassificationCost"],
        "Recall (Val)": val_metrics["Recall"],
        "Precision (Val)": val_metrics["Precision"],
        "ROC_AUC (Val)": val_metrics["ROC_AUC"],
        "Misclassification Cost (Test)": test_metrics["MisclassificationCost"],
        "Recall (Test)": test_metrics["Recall"],
        "Precision (Test)": test_metrics["Precision"],
        "ROC_AUC (Test)": test_metrics["ROC_AUC"]
    }
    return row, val_metrics, test_metrics
# 1) Logistic Regression
benchmark_logit = LogisticRegression(
solver = 'liblinear',
max_iter = 200
).fit(X_training_best, Y_training)
# 2) Shallow Decision Tree
benchmark_tree = DecisionTreeClassifier(
max_depth = 3, # keep it shallow to stay "simple"
min_samples_leaf = 100, # stabilizes thresholds on tabular clinical data
).fit(X_training_best, Y_training)
# Calculate benchmark metrics
benchmark_logit_row, _, _ = evaluate_model(benchmark_logit, "Log Reg", X_val_best, Y_val, X_test_best, Y_test)
benchmark_tree_row, _, _ = evaluate_model(benchmark_tree, "Shallow Tree", X_val_best, Y_val, X_test_best, Y_test)
# Display results
print("Summary of Benchmark Metrics:")
display(pd.DataFrame([benchmark_logit_row, benchmark_tree_row]))
# Show results of benchmarks with chosen NN model
nn_benchmarks_comparison = {
'Model Name': 'NN 5',
'Threshold': sum_te['Threshold'].iloc[0],
'Misclassification Cost': test_cost,
'Recall': sum_te['Recall'].iloc[0],
'Precision': sum_te['Precision'].iloc[0],
'ROC_AUC': auc
}
benchmark_logit_df = pd.DataFrame([{
'Model Name': benchmark_logit_row['Model Name'],
'Threshold': benchmark_logit_row['Threshold'],
'Misclassification Cost': benchmark_logit_row['Misclassification Cost (Test)'],
'Recall': benchmark_logit_row['Recall (Test)'],
'Precision': benchmark_logit_row['Precision (Test)'],
'ROC_AUC': benchmark_logit_row['ROC_AUC (Test)']
}])
benchmark_tree_df = pd.DataFrame([{
'Model Name': benchmark_tree_row['Model Name'],
'Threshold': benchmark_tree_row['Threshold'],
'Misclassification Cost': benchmark_tree_row['Misclassification Cost (Test)'],
'Recall': benchmark_tree_row['Recall (Test)'],
'Precision': benchmark_tree_row['Precision (Test)'],
'ROC_AUC': benchmark_tree_row['ROC_AUC (Test)']
}])
# Concatenate the DataFrames
nn_benchmarks_comparison = pd.concat([pd.DataFrame([nn_benchmarks_comparison]), benchmark_logit_df, benchmark_tree_df], ignore_index = True)
# Display the combined DataFrame
print("\nCombined NN and Benchmark Metrics (on test data):")
display(nn_benchmarks_comparison)
plt.figure(figsize = (16, 5))
plot_tree(benchmark_tree, feature_names = list(X_training_best.columns), class_names = ["no_acute", "acute"], filled = True, max_depth = 3, fontsize = 8)
plt.show()
Summary of Benchmark Metrics:
| | Model Name | Threshold | Misclassification Cost (Val) | Recall (Val) | Precision (Val) | ROC_AUC (Val) | Misclassification Cost (Test) | Recall (Test) | Precision (Test) | ROC_AUC (Test) |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Log Reg | 0.2 | 0.609350 | 0.814217 | 0.295411 | 0.606591 | 0.568981 | 0.859719 | 0.306779 | 0.662537 |
| 1 | Shallow Tree | 0.2 | 0.605445 | 0.830353 | 0.297640 | 0.607903 | 0.585188 | 0.851303 | 0.299704 | 0.631815 |
Combined NN and Benchmark Metrics (on test data):
| | Model Name | Threshold | Misclassification Cost | Recall | Precision | ROC_AUC |
|---|---|---|---|---|---|---|
| 0 | NN 5 | 0.2 | 0.551851 | 0.838477 | 0.313831 | 0.650032 |
| 1 | Log Reg | 0.2 | 0.568981 | 0.859719 | 0.306779 | 0.662537 |
| 2 | Shallow Tree | 0.2 | 0.585188 | 0.851303 | 0.299704 | 0.631815 |
Q2f Step 3 - Final Words for Sharing with Betahelf
Performance of our model
To evaluate performance, we chose an operating threshold of 0.20. A patient whose predicted probability of receiving an acute diagnosis over the next 12 months is at least 0.20 is flagged by our model as likely to receive one.
- We set a low threshold because acute diagnoses can be rare but significant when they do happen (i.e. low-frequency, high-severity events). A lower threshold is therefore the more conservative, "better safe than sorry" choice.
Our chosen model has a recall of 0.84 on the test data. This means that it correctly flags roughly 84% of all patients who go on to receive an acute diagnosis. This is crucial because missing such patients can have adverse consequences for them and cause serious reputational harm to Betahelf.
Our chosen model has a precision of 0.31 on the test data. This implies roughly 2 false alerts for every true acute case, which is significant but operationally workable for Betahelf's proactive outreach teams given the upside of catching most at-risk patients.
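The "roughly 2 false alerts per true case" figure follows directly from the precision. A minimal sketch of the arithmetic is shown below; the helper name `false_alerts_per_true_case` is our own and does not appear elsewhere in the notebook.

```python
def false_alerts_per_true_case(precision: float) -> float:
    """Expected number of false positives for each true positive.

    precision = TP / (TP + FP), so FP / TP = (1 - precision) / precision.
    """
    if not 0 < precision <= 1:
        raise ValueError("precision must be in (0, 1]")
    return (1 - precision) / precision


# Test-set precision of NN 5, taken from the combined metrics table
ratio = false_alerts_per_true_case(0.313831)
print(round(ratio, 2))  # about 2.19
```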
Comparison to benchmarks
To put its performance in context, we compared our chosen model to two benchmark models: a logistic regression and a shallow decision tree.
Both benchmarks incur a higher misclassification cost on the test data than our chosen model (0.569 and 0.585 respectively, against 0.552). Cost is defined with respect to a false negative (i.e. missing a patient who will truly become acute) and a false positive (i.e. flagging a patient who turns out not to be acute), so it reflects the lens that is most relevant to Betahelf's business context.
The logistic regression is a robust baseline with competitive performance: it has the highest ROC AUC (0.663) and recall (0.860) of the three models. At our chosen operating threshold, however, its additional false alerts (precision of 0.307 against 0.314) make its errors more expensive overall.
The shallow decision tree has the lowest ROC AUC (0.632) and the highest cost of the three models at our chosen operating threshold, but it offers transparent, clinically interpretable rules for decision-making.
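To make the idea of a cost built from false negatives and false positives concrete, here is a minimal sketch. The weights `cost_fn` and `cost_fp` are hypothetical placeholders, and the per-patient averaging is an assumption; the cost actually reported in the tables above is computed inside `evaluate_model`, which may use different weights or scaling.

```python
import numpy as np


def misclassification_cost(y_true, y_prob, threshold=0.2, cost_fn=5.0, cost_fp=1.0):
    """Average cost per patient at a given threshold.

    cost_fn and cost_fp are hypothetical weights used for illustration only.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) >= threshold
    fn = np.sum(y_true & ~y_pred)   # acute patients the model missed
    fp = np.sum(~y_true & y_pred)   # non-acute patients the model flagged
    return (cost_fn * fn + cost_fp * fp) / len(y_true)


# Toy example: one missed acute patient and one false alert among four patients
toy_cost = misclassification_cost([1, 1, 0, 0], [0.9, 0.1, 0.3, 0.05])
print(toy_cost)  # (5 * 1 + 1 * 1) / 4 = 1.5
```

Weighting false negatives more heavily than false positives is what pushes the cost-minimising threshold down towards values such as 0.20.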
Key takeaways for Betahelf
Our chosen model is the best aligned to Betahelf's objective of predicting which of its current patients will receive an acute diagnosis over the next 12 months, as it has the lowest misclassification cost of the three models on the test data.
- Our model prioritises detecting at-risk patients while maintaining a sustainable level of false alarms. This directly supports Betahelf's aim of early intervention without overwhelming healthcare teams.
The logistic regression model is a strong back-up in case our model fails. Betahelf should keep it in its monitoring suite as a governance baseline, since having a spare working model is good practice.
If outreach bandwidth tightens, Betahelf could raise the operating threshold to reduce the number of false alerts. If avoiding missed acute cases becomes paramount (e.g. due to seasonal spikes), it could lower the operating threshold modestly.
As with all models, Betahelf should use our chosen model with caution and only as a decision support tool (i.e. it should not use the model to automatically make any decisions).
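The threshold trade-off described in the takeaways can be explored with a simple sweep. The sketch below is illustrative only: `sweep_thresholds` is a hypothetical helper, and the toy labels and probabilities stand in for arrays such as `Y_test` and `nn5.predict(X_test_best).ravel()` from this notebook.

```python
import numpy as np


def sweep_thresholds(y_true, y_prob, thresholds):
    """Return the number of flags, recall and precision at each threshold."""
    y_true = np.asarray(y_true).astype(bool)
    y_prob = np.asarray(y_prob)
    rows = []
    for t in thresholds:
        flagged = y_prob >= t
        tp = int(np.sum(flagged & y_true))
        fp = int(np.sum(flagged & ~y_true))
        fn = int(np.sum(~flagged & y_true))
        rows.append({
            "threshold": t,
            "flagged": int(flagged.sum()),
            "recall": tp / (tp + fn) if (tp + fn) else float("nan"),
            "precision": tp / (tp + fp) if (tp + fp) else float("nan"),
        })
    return rows


# Toy data: two acute patients, three non-acute patients
rows = sweep_thresholds([1, 1, 0, 0, 0], [0.5, 0.25, 0.3, 0.15, 0.05], [0.1, 0.2, 0.4])
for r in rows:
    print(r)
```

In the toy data, raising the threshold from 0.1 to 0.4 cuts the number of flagged patients from 4 to 1 and lifts precision, at the price of halving recall.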
Q2g - Interpret behaviour
Q2g Step 1 - Feature Importance
# Q2g Step 1
# Select a small sample to speed up SHAP calculation
X_sample = X_test_best.sample(n = 1000, random_state = 0)
# Explainer for TensorFlow/Keras model
explainer = shap.Explainer(nn5, X_sample)
# Compute SHAP values
shap_values = explainer(X_sample)
def shap_plot_to_array(shap_plot_func, *args, **kwargs):
    """Render SHAP plot to numpy array (RGB image)."""
    fig = plt.figure()
    shap_plot_func(*args, show=False, **kwargs)
    fig.canvas.draw()
    # Get renderer and RGBA buffer
    renderer = fig.canvas.get_renderer()
    raw_data = renderer.buffer_rgba()
    width, height = fig.canvas.get_width_height()
    # Convert RGBA to RGB image
    img = np.frombuffer(raw_data, dtype=np.uint8).reshape((height, width, 4))
    img_rgb = img[:, :, :3]
    plt.close(fig)
    return img_rgb
# Generate SHAP image arrays (top 15 features)
img_dot = shap_plot_to_array(
shap.summary_plot,
shap_values,
features=X_sample,
feature_names=X_sample.columns,
plot_type="dot",
max_display=15
)
img_bar = shap_plot_to_array(
shap.summary_plot,
shap_values,
features=X_sample,
feature_names=X_sample.columns,
plot_type="bar",
max_display=15
)
# Plot side by side
fig, axes = plt.subplots(1, 2, figsize = (18, 6))
axes[0].imshow(img_dot)
axes[0].axis('off')
axes[0].set_title("SHAP Summary Plot (Top 15 Features - Dot)")
axes[1].imshow(img_bar)
axes[1].axis('off')
axes[1].set_title("SHAP Summary Plot (Top 15 Features - Bar)")
plt.tight_layout()
plt.show()
PermutationExplainer explainer: 1001it [00:35, 21.42it/s]
Q2g Step 2 - Partial Dependence
# Q2g Step 2
def plot_pdp(feature, model, X_reference, ax, grid_resolution=50, use_mean_baseline=True):
    """
    Plot a partial-dependence-style profile for a specified feature into the given axis.

    Note: all other features are held at a single baseline (their mean, or zero)
    rather than averaged over the data, so this approximates a classical PDP.
    """
    # Step 1: Create a grid of values for the target feature
    feature_range = np.linspace(
        X_reference[feature].min(),
        X_reference[feature].max(),
        grid_resolution
    )
    # Step 2: Create a feature grid (with either mean or zero for other features)
    if use_mean_baseline:
        X_grid = pd.DataFrame([X_reference.mean().to_dict()] * grid_resolution)
    else:
        X_grid = pd.DataFrame(np.zeros((grid_resolution, X_reference.shape[1])), columns=X_reference.columns)
    # Step 3: Vary only the selected feature
    X_grid[feature] = feature_range
    # Step 4: Make predictions
    y_preds = model.predict(X_grid).flatten()
    # Step 5: Plot into provided axis
    ax.plot(feature_range, y_preds, marker='o', linestyle='-')
    ax.set_xlabel(feature)
    ax.set_ylabel('Acute Diagnosis Probability')
    ax.set_title(f'{feature}')
    ax.grid(True)
# Top 5 features
top5_features = [
'unique_ICD9CodesDesc_p12m',
'AgeLastBirthday',
'HospitalCountPer100k',
'BelowPovertyLevel',
'MedianMAP_p12m'
]
# Create 2x3 subplot grid
fig, axes = plt.subplots(2, 3, figsize = (20, 10))
# Flatten axes for easier indexing
axes = axes.flatten()
# Plot PDPs for top 5 features
for i, feature in enumerate(top5_features):
    plot_pdp(feature, model=nn5, X_reference=X_test_best, ax=axes[i])
# Hide the last (unused) subplot
axes[5].axis('off')
plt.tight_layout()
plt.show()
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step 2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step
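The curves produced by `plot_pdp` hold every other feature at one baseline value. A classical partial dependence curve instead averages the model's predictions over the observed data at each grid value, and the two can differ for a non-linear model. The sketch below shows the averaging version under stated assumptions: `partial_dependence_curve` and the toy predictor are hypothetical, and with this notebook's objects the `predict_fn` argument could be something like `lambda df: nn5.predict(df, verbose=0).ravel()`.

```python
import numpy as np
import pandas as pd


def partial_dependence_curve(predict_fn, X, feature, grid_resolution=20):
    """Average prediction over all rows of X for each grid value of one feature."""
    grid = np.linspace(X[feature].min(), X[feature].max(), grid_resolution)
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[feature] = value          # overwrite the feature for every row
        averages.append(float(np.mean(predict_fn(X_mod))))
    return grid, np.array(averages)


def toy_predict(df):
    """Stand-in model with an interaction between the two features."""
    return (df["a"] * df["b"]).to_numpy()


toy = pd.DataFrame({"a": [0.0, 1.0, 2.0], "b": [1.0, 2.0, 3.0]})
grid, curve = partial_dependence_curve(toy_predict, toy, "a", grid_resolution=3)
print(grid, curve)
```

Averaging over the data costs one full prediction pass per grid point, so a subsample of the reference data may be preferable for a large test set.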
Q2g Step 3 - SHAP Explanations
# Q2g Step 3
# Step 1: Predict acute diagnosis probabilities
y_test_probs = nn5.predict(X_test_best).ravel()
# Step 2: Identify indices
high_pred = np.argmax(y_test_probs) # Highest acute probability
low_pred = np.argmin(y_test_probs) # Lowest acute probability
border_pred = np.argmin(np.abs(y_test_probs - 0.2)) # Closest to 0.2 threshold
# Step 3: Select these three samples
selected_indices = [high_pred, low_pred, border_pred]
X_explain = X_test_best.iloc[selected_indices]
# Step 4: SHAP explanation
explainer = shap.Explainer(nn5, X_test_best)
shap_values = explainer(X_explain)
# Step 5: Plot each case separately and show the top 10 features for each case
labels = ["Highest Likelihood of Acute Diagnosis", "Lowest Likelihood of Acute Diagnosis", "Borderline Acute Case"]
for i, pred in enumerate(selected_indices):
    prob = y_test_probs[pred]
    print(f"{labels[i]} (Predicted Acute Diagnosis Probability: {prob:.4f}):")
    shap.plots.waterfall(shap_values[i], max_display = 10)
305/305 ━━━━━━━━━━━━━━━━━━━━ 0s 399us/step
Highest Likelihood of Acute Diagnosis (Predicted Acute Diagnosis Probability: 0.4592):
Lowest Likelihood of Acute Diagnosis (Predicted Acute Diagnosis Probability: 0.1143):
Borderline Acute Case (Predicted Acute Diagnosis Probability: 0.2000):
Q2g Step 4 - Fairness Metrics
# Q2g Step 4
# Analyse 'BelowPovertyLevel' for possible socioeconomic bias
# Analyse 'Gender' for possible gender bias
# Step 1: Calculate SHAP values on a sample to speed things up
explainer = shap.Explainer(nn5, X_sample)
shap_values_sample = explainer(X_sample)
# Step 2: Find observations with highest and lowest SHAP values for 'BelowPovertyLevel' and 'Gender'
feature_below_poverty = 'BelowPovertyLevel'
feature_gender = 'Gender'
# Get column indices from the sample DataFrame
col_idx_below_poverty = X_sample.columns.get_loc(feature_below_poverty)
col_idx_gender = X_sample.columns.get_loc(feature_gender)
# Get SHAP values for the specific features from the sample SHAP values
shap_below_poverty = shap_values_sample.values[:, col_idx_below_poverty]
shap_gender = shap_values_sample.values[:, col_idx_gender]
# Find the indices of max and min SHAP values
idx_below_poverty_high_shap_sample = np.argmax(shap_below_poverty)
idx_below_poverty_low_shap_sample = np.argmin(shap_below_poverty)
idx_gender_high_shap_sample = np.argmax(shap_gender)
idx_gender_low_shap_sample = np.argmin(shap_gender)
# Step 3: Select these specific instances from the sample
selected_indices_sample = [
idx_below_poverty_high_shap_sample,
idx_below_poverty_low_shap_sample,
idx_gender_high_shap_sample,
idx_gender_low_shap_sample
]
X_explain_specific_sample = X_sample.iloc[selected_indices_sample]
shap_values_explain_specific_sample = shap_values_sample[selected_indices_sample]
# Step 4: Get the predicted probabilities for the selected samples
corresponding_probs = nn5.predict(X_explain_specific_sample).ravel()
# Step 5: Plot waterfall charts for each selected instance from the sample
labels = [
f"Highest SHAP for {feature_below_poverty} (from sample)",
f"Lowest SHAP for {feature_below_poverty} (from sample)",
f"Highest SHAP for {feature_gender} (from sample)",
f"Lowest SHAP for {feature_gender} (from sample)"
]
for i in range(len(selected_indices_sample)):
    prob = corresponding_probs[i]
    print(f"\n{labels[i]} (Predicted Acute Diagnosis Probability: {prob:.4f}):")
    shap.plots.waterfall(shap_values_explain_specific_sample[i], max_display = 10)
PermutationExplainer explainer: 1001it [00:33, 20.81it/s]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step
Highest SHAP for BelowPovertyLevel (from sample) (Predicted Acute Diagnosis Probability: 0.3587):
Lowest SHAP for BelowPovertyLevel (from sample) (Predicted Acute Diagnosis Probability: 0.2612):
Highest SHAP for Gender (from sample) (Predicted Acute Diagnosis Probability: 0.3320):
Lowest SHAP for Gender (from sample) (Predicted Acute Diagnosis Probability: 0.2435):
Q2g Step 5 - Final Words for Sharing with Betahelf
How the model makes its predictions
Our model's predictions tell a consistent story. It relies most on signals of clinical burden and recent activity, complemented by state-level access and socioeconomic context.
- In plain terms, patients with more abnormal lab results in the past 12 months, a broader mix of diagnoses and specialty encounters (i.e. more unique ICD-9 code descriptions), and younger ages are assigned higher predicted probabilities of receiving an acute diagnosis over the next 12 months. Context variables, such as hospital count per 100,000 people and the percentage of the state's population below the poverty level, also play an important part.
This aligns with real-world expectations as acute diagnoses are likely driven primarily by recent clinical instability (e.g. abnormal labs, many comorbidities) and complemented by environmental factors working in the background.
How a patient's risk of acute diagnosis changes when predictors move
As the 12-month historical mix of diagnoses and specialty encounters for a patient broadens, the predicted likelihood that the patient receives an acute diagnosis over the next 12 months increases.
- However, this relationship is not linear. At low counts, each increase in this mix raises the predicted likelihood much more rapidly than it does at higher counts.
The predicted probability also falls with age, and it falls much more rapidly with each additional year for very old patients than for younger patients.
Why a patient was flagged as acute
We examined the driving factors behind three representative patients in our model - the lowest acute probability, highest acute probability, and a borderline acute case, which are all intuitive "edge cases".
In particular, the borderline case had mixed signals: its prediction was driven mostly by the patient having almost no diagnoses or specialty encounters, partially offset by a younger age.
The above example is informative as borderline cases are very important to Betahelf and the analysis helps it understand the factors that "make or break" an acute case.
Together, these explanations confirm that the same global drivers are also locally responsible for individual flags, improving trust in the model.
Fairness checks
We found some potential socioeconomic bias: the model uses state-level poverty rates, and these have a moderate effect on its predictions.
However, this can be mitigated if Betahelf uses the model only as a decision support tool (i.e. it does not let the model make any decisions automatically) and considers the bias explicitly when making final decisions.
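The checks above rely on SHAP values for individual patients. One complementary probe is to compare flag rates and recall across groups at the operating threshold. The sketch below is a hypothetical illustration: `group_rates` is our own helper, and the group labels are assumed to come from the unscaled source data (for a continuous variable such as the poverty rate they would first need to be binned).

```python
import numpy as np
import pandas as pd


def group_rates(y_true, y_prob, group, threshold=0.2):
    """Flag rate and recall per group at a fixed threshold."""
    df = pd.DataFrame({
        "y_true": np.asarray(y_true).astype(bool),
        "flagged": np.asarray(y_prob) >= threshold,
        "group": np.asarray(group),
    })
    rows = []
    for name, g in df.groupby("group"):
        positives = g["y_true"].sum()
        caught = (g["flagged"] & g["y_true"]).sum()
        rows.append({
            "group": name,
            "n": len(g),
            "flag_rate": g["flagged"].mean(),
            "recall": caught / positives if positives else float("nan"),
        })
    return pd.DataFrame(rows)


# Toy example with two groups of two patients each
out = group_rates([1, 0, 1, 0], [0.3, 0.1, 0.1, 0.4], ["F", "F", "M", "M"])
print(out)
```

Large gaps in recall between groups would indicate that the model misses acute cases more often for one group, which is the kind of disparity Betahelf's governance process would want to monitor.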
Bottom line for Betahelf
Our chosen model is behaving as a clinician would expect - prioritising signs of recent clinical instability while taking into account environmental factors.
Its decisions are explainable at both the aggregate and individual patient levels, and our fairness probes suggest that the model is using contextual features sensibly. This supports confident deployment of the model for early, targeted interventions with ongoing governance to sustain equity and performance.
Q3 - AI Pathologists
Q3a - Python prototype
Q3a Step 1 - Build Prototype
# Q3a Step 1
# ============================ LLM Prompt Prototype ============================
# Prereq to run the model locally:
# pip install transformers accelerate torch --upgrade
# Zephyr-7B-alpha is a 7B model; device_map="auto" will try to use GPU if available.
# Assumes these DataFrames already exist:
# labresult_df, labobservation_df, pathology_df, patient_df, smoking_df
# 1) Sample LabResultGuids
def sample_labresult_guids(n: int = 5, seed: Optional[int] = None) -> List[str]:
    if seed is not None:
        random.seed(seed)
    guids = labresult_df["LabResultGuid"].dropna().astype(str).unique().tolist()
    return random.sample(guids, min(n, len(guids)))
# 2) Fetch Patient + Smoking (as-of) + LabObservations
def fetch_context_for_labresult(labresultguid: str) -> Dict:
    lr = labresult_df[labresult_df["LabResultGuid"].astype(str) == str(labresultguid)]
    # fail with a clear message instead of an IndexError on an unknown GUID
    if lr.empty:
        raise KeyError(f"LabResultGuid not found: {labresultguid}")
    lr_row = lr.iloc[0].to_dict()
    # link to patient
    pt_key = "PatientGuid"
    patient_guid = lr_row[pt_key]
    # patient record
    p = patient_df[patient_df[pt_key] == patient_guid]
    patient = p.iloc[0].to_dict() if not p.empty else {}
    # smoking as-of (if time slices exist); fall back to most recent
    ts = lr_row.get("Timestamp", None)
    sm = smoking_df[smoking_df[pt_key] == patient_guid].copy()
    smoking = {}
    if not sm.empty:
        # try to align by ValidFrom/ValidTo if those columns exist
        if "ValidFrom" in sm.columns:
            sm["ValidFrom"] = pd.to_datetime(sm["ValidFrom"], errors="coerce")
        if "ValidTo" in sm.columns:
            sm["ValidTo"] = pd.to_datetime(sm["ValidTo"], errors="coerce")
        if ts is not None:
            try:
                ts = pd.to_datetime(ts, errors="coerce")
            except Exception:
                pass
        if (ts is not None) and pd.notnull(ts) and {"ValidFrom","ValidTo"}.issubset(sm.columns):
            # keep the smoking record whose validity window contains the lab timestamp
            asof = sm[(sm["ValidFrom"] <= ts) & (sm["ValidTo"].isna() | (sm["ValidTo"] >= ts))]
            if asof.empty:
                asof = sm.sort_values("ValidFrom", na_position="first").tail(1)
            smoking = asof.iloc[0].to_dict()
        else:
            sort_col = "ValidFrom" if "ValidFrom" in sm.columns else sm.columns[0]
            smoking = sm.sort_values(sort_col, na_position="first").tail(1).iloc[0].to_dict()
    # Lab observations for this LabResultGuid excluding first two columns
    lo_raw = labobservation_df[labobservation_df["LabResultGuid"].astype(str) == str(labresultguid)].copy()
    lo_raw = lo_raw.loc[:, lo_raw.columns[2:]]
    return {"labresult": lr_row, "patient": patient, "smoking": smoking, "lab_observations_raw": lo_raw}
# 3) Build prompt
# EXAMPLE_OUTPUT is a sample report kept for few-shot experiments; build_prompt() below does not insert it into the prompt.
EXAMPLE_OUTPUT = """**Pathology Report**
**Patient Information:**
- Name: Gabriel Warwick
- Date of Birth: March 27, 1994
- Smoking Status: 0 cigarettes per day (previous smoker)
---
**Blood Test Results:**
1. Hematocrit: 38.3% (Reference Range: 36.0-50.0)
- Result: Within normal range
2. Immature Granulocytes: 74.2 g/dL
- Result: Abnormal, further investigation may be required
3. Ketones: Not Available
- Result: Not available for analysis
---
**Comments:**
- The Immature Granulocytes level of 74.2 g/dL is abnormal and may indicate an ongoing infection or inflammation. Further evaluation and monitoring may be needed.
- The patient's previous smoking status may have contributed to the abnormal blood test results. It is important to address smoking cessation and its potential impact on overall health.
---
**Recommendations:**
1. Follow-up testing to monitor the Immature Granulocytes levels and investigate possible underlying conditions.
2. Encourage lifestyle modifications, including smoking cessation, to improve overall health and potentially normalize future blood test results.
"""
def build_prompt(labresultguid: str) -> str:
    ctx = fetch_context_for_labresult(labresultguid)
    p, s, lo_raw = ctx["patient"], ctx["smoking"], ctx["lab_observations_raw"]
    # simple name/date
    def patient_name(p):
        given = p.get("GivenName") or p.get("FirstName") or ""
        sur = p.get("Surname") or p.get("LastName") or ""
        full = f"{str(given).strip()} {str(sur).strip()}".strip()
        return full or "Unknown"
    name = patient_name(p) if p else "Unknown"
    dob = p.get("DateOfBirth")
    try:
        dob_str = pd.to_datetime(dob).strftime("%B %d, %Y") if pd.notnull(dob) else "Unknown"
    except Exception:
        dob_str = str(dob) if dob is not None else "Unknown"
    smoke_desc = s.get("Description") if s else None
    smoke_str = str(smoke_desc) if smoke_desc is not None else "Unknown"
    # RAW subset → CSV string.
    lo_sub = lo_raw
    lo_csv = "No blood test results provided for analysis."
    if not lo_sub.empty:
        lo_csv = lo_sub.to_csv(index=False)
    return f"""CONTEXT
You are a pathologist writing concise, clinically useful pathology reports to document findings and recommendations of each laboratory test. Each LabResultGuid corresponds to a test session for one patient. Your goal is to summarize salient abnormalities, highlight patient risk factors, and provide actionable recommendations suitable for care-management follow-up.
INPUT DATA
- Patient’s name: {name}
- Date of Birth: {dob_str}
- Smoking Status: {smoke_str}
- Table of Lab Observations (CSV rows):
{lo_csv}
COMMAND
- Write a pathology report which uses all the inputs above.
- Correct any spelling errors.
OUTPUT REQUIREMENTS
- Title the report as "Pathology Report" with the following four sections - Patient Information, Blood Test Results, Commentary, Recommendations.
- Under the “Patient Information” section, include name, date of birth, and smoking status.
- Under the “Blood Test Results” section, list, in numbered points, each observation with result, reference range if available, and a brief note (e.g., “Within range” / “Abnormal”).
- Under the “Commentary” section, summarise what the results suggest for risk stratification and potential next steps.
- Under the “Recommendations” section, provide practical, action-oriented follow-ups (labs to repeat, monitoring, lifestyle / care-management prompts).
CONSTRAINTS
- If no observations are available, state that clearly and focus on next steps.
- Use professional, clear language suitable for care-management and providers.
*End of prompt*
"""
# 4) Build prompts for 5 random LabResultGuids
def build_prompts_df(guids: Optional[List[str]] = None, n_random: int = 5, seed: Optional[int] = None) -> pd.DataFrame:
    if guids is None:
        guids = sample_labresult_guids(n=n_random, seed=seed)
    path_map = pathology_df.set_index("LabResultGuid")["written_report"].astype(str)
    rows = []
    for g in guids:
        wr = path_map.get(g, None)
        rows.append({"LabResultGuid": g, "written_report": wr, "Prompt": build_prompt(g)})
    return pd.DataFrame(rows)
# 5) Run Zephyr-7B-alpha locally via Transformers pipeline
def run_zephyr(prompts_df: pd.DataFrame,
               max_new_tokens: int = 768, temperature: float = 0.7,
               top_k: int = 50, top_p: float = 0.95) -> pd.DataFrame:
    # Create a single pipeline and reuse it for all prompts (faster).
    pipe = pipeline(
        "text-generation",
        model="HuggingFaceH4/zephyr-7b-alpha",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    results = []
    for i, row in prompts_df.iterrows():
        prompt = row["Prompt"]
        out = pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        raw = out[0]["generated_text"]
        results.append(raw)
    prompts_df = prompts_df.copy()
    prompts_df["ResultsFromLLM"] = results
    return prompts_df
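By default the text-generation pipeline returns the prompt followed by the completion, which is why each `ResultsFromLLM` entry in this notebook begins with the `CONTEXT` block. If only the generated report is wanted, the echoed prompt can be removed afterwards. The helper below is a minimal sketch with a hypothetical name, `strip_prompt_echo`; passing `return_full_text=False` to the pipeline call should achieve a similar effect at generation time.

```python
def strip_prompt_echo(generated_text: str, prompt: str) -> str:
    """Return only the completion when the model output starts with the prompt."""
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].lstrip()
    # Output did not echo the prompt, so return it unchanged apart from whitespace
    return generated_text.strip()


# Toy example
example_prompt = "CONTEXT\nWrite a report.\n*End of prompt*\n"
example_output = example_prompt + "Pathology Report\nPatient Information"
print(strip_prompt_echo(example_output, example_prompt))
```

With the notebook's DataFrame, this could be applied row by row, for example `prompts_df.apply(lambda r: strip_prompt_echo(r["ResultsFromLLM"], r["Prompt"]), axis=1)`.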
Q3a Step 2 - Run Prototype
# Q3a Step 2
# Build prompts for five random lab results, run them through zephyr, and store outputs
prompts_df = build_prompts_df(n_random=5, seed=42)
prompts_df = run_zephyr(prompts_df)
/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets. warnings.warn(
`torch_dtype` is deprecated! Use `dtype` instead!
Device set to use cuda:0
Savepoint
prompts_df.to_pickle('/content/gdrive/My Drive/DSA Assignment Data/prompts_df.pkl')
Loadpoint
prompts_df = pd.read_pickle('/content/gdrive/My Drive/DSA Assignment Data/prompts_df.pkl')
Q3a Step 3 - Checks on LLM Output
# Q3a Step 3
# Check the DataFrame to confirm that the prompts have been run through the model and there is output
display(prompts_df)
| | LabResultGuid | written_report | Prompt | ResultsFromLLM |
|---|---|---|---|---|
| 0 | fc641eaf-1dce-48bb-a004-5b2f7b93715b | **Pathology Report:**\nPatient: Julie Kemper\n... | CONTEXT\n\nYou are a pathologist writing conci... | CONTEXT\n\nYou are a pathologist writing conci... |
| 1 | dec4ba11-fecc-4ba3-9ef2-2ec4f6f2d9d3 | ## Pathology Report\n### Patient Information:\... | CONTEXT\n\nYou are a pathologist writing conci... | CONTEXT\n\nYou are a pathologist writing conci... |
| 2 | cf9823b9-3666-4fb9-a494-608d11ae74cd | **Pathology Report:**\n**Patient Information:*... | CONTEXT\n\nYou are a pathologist writing conci... | CONTEXT\n\nYou are a pathologist writing conci... |
| 3 | 75177378-be69-4fcb-9c3b-3981815bd9cb | **Pathology Report:**\n- **AST (SGOT) (Asparta... | CONTEXT\n\nYou are a pathologist writing conci... | CONTEXT\n\nYou are a pathologist writing conci... |
| 4 | f715878a-45d5-4f62-aa73-c21399cd89f3 | **Pathology Report:**\nPatient Name: Justin Fo... | CONTEXT\n\nYou are a pathologist writing conci... | CONTEXT\n\nYou are a pathologist writing conci... |
# Q3a Step 3
# Print the results from each LLM to check that there is sensible output
for i, row in prompts_df.iterrows():
    print(f"Results from LLM for LabResultGuid - {row['LabResultGuid']} (Sample {i + 1}):")
    print()
    print(row['ResultsFromLLM'])
    # print a separator after every sample except the last one
    if i < len(prompts_df) - 1:
        print()
        print("-" * 400)
        print("-" * 400)
        print()
Results from LLM for LabResultGuid - fc641eaf-1dce-48bb-a004-5b2f7b93715b (Sample 1):
CONTEXT
You are a pathologist writing concise, clinically useful pathology reports to document findings and recommendations of each laboratory test. Each LabResultGuid corresponds to a test session for one patient. Your goal is to summarize salient abnormalities, highlight patient risk factors, and provide actionable recommendations suitable for care-management follow-up.
INPUT DATA
- Patient’s name: Cecile Wise
- Date of Birth: September 01, 1982
- Smoking Status: 0 cigaretttes per day (non-smoker or less than 100 in lifetime)
- Table of Lab Observations (CSV rows):
HL7Text,ObservationValue,Units,ReferenceRange,AbnormalFlags,IsAbnormalValue
Cannabinoids,0.7,g/dL,,,False
Thyroxine,,,,,False
Urobilinogen,5.2,g/dL,,,False
Triglyceride,84.0,mg/dL,1.8-7.8,,False
Urobilinogen,32.0,mg/dL,Yellow,Above Normal High,True
COMMAND
- Write a pathology report which uses all the inputs above.
- Correct any spelling errors.
OUTPUT REQUIREMENTS
- Title the report as "Pathology Report" with the following four sections - Patient Information, Blood Test Results, Commentary, Recommendations.
- Under the “Patient Information” section, include name, date of birth, and smoking status.
- Under the “Blood Test Results” section, list, in numbered points, each observation with result, reference range if available, and a brief note (e.g., “Within range” / “Abnormal”).
- Under the “Commentary” section, summarise what the results suggest for risk stratification and potential next steps.
- Under the “Recommendations” section, provide practical, action-oriented follow-ups (labs to repeat, monitoring, lifestyle / care-management prompts).
CONSTRAINTS
- If no observations are available, state that clearly and focus on next steps.
- Use professional, clear language suitable for care-management and providers.
*End of prompt*
Pathology Report
Patient Information
Name: Cecile Wise
Date of Birth: September 01, 1982
Smoking Status: 0 cigarettes per day (non-smoker or less than 100 in lifetime)
Blood Test Results
1. Cannabinoids: 0.7 g/dL
- Reference Range: N/A
- Note: Within range
2. Thyroxine:
- Reference Range: N/A
- Note: Within range
3. Urobilinogen:
- 5.2 g/dL
- Reference Range: 0.0-1.0 g/dL
- Note: Within range
4. Triglyceride: 84.0 mg/dL
- Reference Range: 1.8-7.8 mg/dL
- Note: Elevated
5. Urobilinogen:
- 32.0 mg/dL
- Reference Range: N/A
- Note: Abnormal (above normal high)
Commentary
The blood test results for Cecile Wise suggest potential liver dysfunction, as indicated by the elevated levels of urobilinogen. The urobilinogen level of 32.0 mg/dL is above normal high and may indicate a potential biliary obstruction. This risk factor should be further evaluated through additional liver function tests and imaging studies.
Recommendations
We recommend that Cecile Wise undergo liver function tests, including alkaline phosphatase and bilirubin, to further evaluate the potential liver dysfunction. She should also be advised to avoid excessive alcohol consumption and maintain a healthy diet and lifestyle. A follow-up appointment should be scheduled to discuss the results and determine the appropriate care management plan.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - dec4ba11-fecc-4ba3-9ef2-2ec4f6f2d9d3 (Sample 2):
CONTEXT
You are a pathologist writing concise, clinically useful pathology reports to document findings and recommendations of each laboratory test. Each LabResultGuid corresponds to a test session for one patient. Your goal is to summarize salient abnormalities, highlight patient risk factors, and provide actionable recommendations suitable for care-management follow-up.
INPUT DATA
- Patient’s name: Fern Dugger
- Date of Birth: March 28, 1994
- Smoking Status: 0 cigaretttes per day (non-smoker or less than 100 in lifetime)
- Table of Lab Observations (CSV rows):
HL7Text,ObservationValue,Units,ReferenceRange,AbnormalFlags,IsAbnormalValue
Albumin / Globulin Ratio,1.8,,6.0-8.5,,False
Bilirubin,0.3,mg/dL,97-108,,False
"Chloride, Serum",104.0,mmol/L,3.5-5.2,,False
Globulin,2.5,g/dL,1.1-2.5,,False
Hematocrit,37.3,%,11.5-15.0,,False
Hemoglobin,12.1,g/dL,0-149,,False
Monocytes,0.6,x10E3/uL,140-415,Above Normal High,True
Neutrophils,1.7,x10E3/uL,0.1-1.0,,False
Platelet Count,145.0,x10E3/uL,1.5-4.5,,False
"Potassium, Serum",3.9,mmol/L,Negative,,False
Protein Total,6.7,g/dL,0.0-1.2,,False
Triglyceride,164.0,mg/dL,1.8-7.8,Above Normal High,True
COMMAND
- Write a pathology report which uses all the inputs above.
- Correct any spelling errors.
OUTPUT REQUIREMENTS
- Title the report as "Pathology Report" with the following four sections - Patient Information, Blood Test Results, Commentary, Recommendations.
- Under the “Patient Information” section, include name, date of birth, and smoking status.
- Under the “Blood Test Results” section, list, in numbered points, each observation with result, reference range if available, and a brief note (e.g., “Within range” / “Abnormal”).
- Under the “Commentary” section, summarise what the results suggest for risk stratification and potential next steps.
- Under the “Recommendations” section, provide practical, action-oriented follow-ups (labs to repeat, monitoring, lifestyle / care-management prompts).
CONSTRAINTS
- If no observations are available, state that clearly and focus on next steps.
- Use professional, clear language suitable for care-management and providers.
*End of prompt*
Pathology Report
Patient Information:
Name: Fern Dugger
Date of Birth: March 28, 1994
Smoking Status: 0 cigarettes per day (non-smoker or less than 100 in lifetime)
Blood Test Results:
1. Albumin / Globulin Ratio: 1.8
Reference Range: 6.0-8.5
Result: Within range
2. Bilirubin: 0.3 mg/dL
Reference Range: 97-108 mg/dL
Result: Within range
3. Chloride, Serum: 104.0 mmol/L
Reference Range: 3.5-5.2 mmol/L
Result: Within range
4. Globulin: 2.5 g/dL
Reference Range: 1.1-2.5 g/dL
Result: Within range
5. Hematocrit: 37.3%
Reference Range: 11.5-15.0%
Result: Within range
6. Hemoglobin: 12.1 g/dL
Reference Range: 0-149 g/dL
Result: Within range
7. Monocytes: 0.6 x10E3/uL
Reference Range: 140-415 x10E3/uL
Result: Above Normal High
8. Neutrophils: 1.7 x10E3/uL
Reference Range: 0.1-1.0 x10E3/uL
Result: Within range
9. Platelet Count: 145.0 x10E3/uL
Reference Range: 1.5-4.5 x10E3/uL
Result: Within range
10. Potassium, Serum: 3.9 mmol/L
Reference Range: Negative
Result: Within range
11. Protein Total: 6.7 g/dL
Reference Range: 0.0-1.2 g/dL
Result: Within range
12. Triglyceride: 164.0 mg/dL
Reference Range: 1.8-7.8 mg/dL
Result: Above Normal High
Commentary:
Fern Dugger's blood test results indicate that her liver and kidney function are normal, with no significant abnormalities. However, her monocyte count is above normal high, which may be indicative of an inflammatory process, infection, or malignancy. Further investigation and monitoring are recommended to identify the underlying cause and to assess the need for additional testing or treatment.
Recommendations:
1. Repeat a complete blood count (CBC) in 2-3 weeks to monitor monocyte count and to assess any changes in other blood cell lines.
2. Evaluate the patient's symptoms, medical history, and physical examination findings to determine the cause of the elevated monocyte count.
3. Consider additional testing, such as a C-reactive protein (CRP) or erythrocyte sedimentation rate (ESR), to assess for inflammation or infection.
4. Monitor the patient's symptoms and follow-up with her primary care physician or specialist as needed.
5. Encourage lifestyle modifications, such as quitting smoking, maintaining a healthy weight, and engaging in regular physical activity, to promote overall health and wellness.
*End of report*
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - cf9823b9-3666-4fb9-a494-608d11ae74cd (Sample 3):
CONTEXT
You are a pathologist writing concise, clinically useful pathology reports to document findings and recommendations of each laboratory test. Each LabResultGuid corresponds to a test session for one patient. Your goal is to summarize salient abnormalities, highlight patient risk factors, and provide actionable recommendations suitable for care-management follow-up.
INPUT DATA
- Patient’s name: Brandon Chase
- Date of Birth: March 28, 2000
- Smoking Status: 0 cigaretttes per day (previous smoker)
- Table of Lab Observations (CSV rows):
HL7Text,ObservationValue,Units,ReferenceRange,AbnormalFlags,IsAbnormalValue
Albumin / Globulin Ratio,1.6,,6.0-8.5,,False
Ambig Abbrev CMP14 Default,,,34.0-44.0,,False
Ambig Abbrev LP Default,,,34.0-44.0,,False
Bilirubin,0.4,mg/dL,97-108,Above Normal High,True
"Chloride, Serum",96.0,mmol/L,3.5-5.2,,False
Globulin,2.5,g/dL,1.1-2.5,,False
Hematocrit,37.9,%,36.0-50.0,,False
Hemoglobin,12.6,g/dL,12.5-17.0,Below Normal Low,True
Monocytes,0.6,x10E3/uL,140-415,,False
Neutrophils,6.9,x10E3/uL,0.1-1.0,,False
Platelet Count,253.0,x10E3/uL,1.5-4.5,,False
"Potassium, Serum",3.7,mmol/L,Negative,,False
Protein Total,6.9,g/dL,0.0-1.2,,False
Triglyceride,170.0,mg/dL,1.8-7.8,Above Normal High,True
COMMAND
- Write a pathology report which uses all the inputs above.
- Correct any spelling errors.
OUTPUT REQUIREMENTS
- Title the report as "Pathology Report" with the following four sections - Patient Information, Blood Test Results, Commentary, Recommendations.
- Under the “Patient Information” section, include name, date of birth, and smoking status.
- Under the “Blood Test Results” section, list, in numbered points, each observation with result, reference range if available, and a brief note (e.g., “Within range” / “Abnormal”).
- Under the “Commentary” section, summarise what the results suggest for risk stratification and potential next steps.
- Under the “Recommendations” section, provide practical, action-oriented follow-ups (labs to repeat, monitoring, lifestyle / care-management prompts).
CONSTRAINTS
- If no observations are available, state that clearly and focus on next steps.
- Use professional, clear language suitable for care-management and providers.
*End of prompt*
Pathology Report
Patient Information
Name: Brandon Chase
Date of Birth: March 28, 2000
Smoking Status: 0 cigarettes per day (previous smoker)
Blood Test Results
1. Albumin / Globulin Ratio: 1.6
Reference Range: 6.0-8.5
Note: Within range
2. Ambig Abbrev CMP14 Default:
Reference Range:
Note: Within range
3. Ambig Abbrev LP Default:
Reference Range:
Note: Within range
4. Bilirubin: 0.4mg/dL
Reference Range: 97-108
Note: Above Normal High (Abnormal)
5. Chloride, Serum: 96.0mmol/L
Reference Range: 3.5-5.2
Note: Within range
6. Globulin: 2.5g/dL
Reference Range: 1.1-2.5
Note: Within range
7. Hematocrit: 37.9%
Reference Range: 36.0-50.0
Note: Within range
8. Hemoglobin: 12.6g/dL
Reference Range: 12.5-17.0
Note: Below Normal Low (Abnormal)
9. Monocytes: 0.6x10E3/uL
Reference Range: 140-415
Note: Within range
10. Neutrophils: 6.9x10E3/uL
Reference Range: 0.1-1.0
Note: Within range
11. Platelet Count: 253.0x10E3/uL
Reference Range: 1.5-4.5
Note: Within range
12. Potassium, Serum: 3.7mmol/L
Reference Range: Negative
Note: Within range
13. Protein Total: 6.9g/dL
Reference Range: 0.0-1.2
Note: Within range
14. Triglyceride: 170.0mg/dL
Reference Range: 1.8-7.8
Note: Above Normal High (Abnormal)
Commentary
The blood test results suggest that Brandon Chase has elevated bilirubin levels, which may indicate liver or biliary tract problems. Additionally, his hemoglobin levels are lower than expected, which may be indicative of anemia. Further testing may be necessary to determine the cause of these abnormalities.
Recommendations
- Repeat liver function tests to monitor for any changes and investigate the cause of the elevated bilirubin levels.
- Consider performing a complete blood count (CBC) with differential to investigate the cause of the low hemoglobin levels.
- Encourage Brandon Chase to adopt a healthier lifestyle, including a balanced diet, regular exercise, and avoiding smoking and excessive alcohol consumption.
- Monitor his blood pressure and cholesterol levels to ensure they remain within a healthy range.
- Recommend that Brandon Chase schedule a follow-up appointment with his primary care physician to discuss these results and any potential next steps.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - 75177378-be69-4fcb-9c3b-3981815bd9cb (Sample 4):
CONTEXT
You are a pathologist writing concise, clinically useful pathology reports to document findings and recommendations of each laboratory test. Each LabResultGuid corresponds to a test session for one patient. Your goal is to summarize salient abnormalities, highlight patient risk factors, and provide actionable recommendations suitable for care-management follow-up.
INPUT DATA
- Patient’s name: Barbara Cubbage
- Date of Birth: December 05, 2003
- Smoking Status: 0 cigaretttes per day (non-smoker or less than 100 in lifetime)
- Table of Lab Observations (CSV rows):
HL7Text,ObservationValue,Units,ReferenceRange,AbnormalFlags,IsAbnormalValue
AST (SGOT),,,,,False
Ambig Abbrev CMP14 Default,,,,,False
Bilirubin,1.2,mg/dL,97-108,,False
Bilirubin,,,,Below Normal Low,True
"Chloride, Serum",10.0,mmol/L,3.5-5.2,,False
COMMAND
- Write a pathology report which uses all the inputs above.
- Correct any spelling errors.
OUTPUT REQUIREMENTS
- Title the report as "Pathology Report" with the following four sections - Patient Information, Blood Test Results, Commentary, Recommendations.
- Under the “Patient Information” section, include name, date of birth, and smoking status.
- Under the “Blood Test Results” section, list, in numbered points, each observation with result, reference range if available, and a brief note (e.g., “Within range” / “Abnormal”).
- Under the “Commentary” section, summarise what the results suggest for risk stratification and potential next steps.
- Under the “Recommendations” section, provide practical, action-oriented follow-ups (labs to repeat, monitoring, lifestyle / care-management prompts).
CONSTRAINTS
- If no observations are available, state that clearly and focus on next steps.
- Use professional, clear language suitable for care-management and providers.
*End of prompt*
Pathology Report
Patient Information:
Name: Barbara Cubbage
Date of Birth: December 05, 2003
Smoking Status: 0 cigaretttes per day (non-smoker or less than 100 in lifetime)
Blood Test Results:
1. AST (SGOT):
Observation Value: N/A
Units: N/A
Reference Range: N/A
AbnormalFlags: False
IsAbnormalValue: False
Notes: Within range
2. Ambig Abbrev CMP14 Default:
Observation Value: N/A
Units: N/A
Reference Range: N/A
AbnormalFlags: False
IsAbnormalValue: False
Notes: Within range
3. Bilirubin:
Observation Value: 1.2
Units: mg/dL
Reference Range: 97-108
AbnormalFlags: False
IsAbnormalValue: False
Notes: Below Normal Low (Note: Bilirubin is lower than the lower limit of normal)
Commentary:
The patient's blood test results suggest that her liver function (AST) is within the normal range. However, her bilirubin level is lower than the lower limit of normal, which may indicate potential liver dysfunction. Further investigation or follow-up testing may be necessary.
Recommendations:
1. Repeat liver function tests (AST, ALT, ALP) in 4-6 weeks to monitor potential liver dysfunction.
2. Monitor bilirubin levels and investigate potential underlying causes, such as hepatitis, cholestasis, or drug-induced liver injury.
3. Encourage patient to avoid alcohol and maintain a healthy weight to reduce potential liver damage.
*End of report*
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - f715878a-45d5-4f62-aa73-c21399cd89f3 (Sample 5):
CONTEXT
You are a pathologist writing concise, clinically useful pathology reports to document findings and recommendations of each laboratory test. Each LabResultGuid corresponds to a test session for one patient. Your goal is to summarize salient abnormalities, highlight patient risk factors, and provide actionable recommendations suitable for care-management follow-up.
INPUT DATA
- Patient’s name: Justin Ford
- Date of Birth: June 11, 1991
- Smoking Status: 0 cigaretttes per day (previous smoker)
- Table of Lab Observations (CSV rows):
HL7Text,ObservationValue,Units,ReferenceRange,AbnormalFlags,IsAbnormalValue
Albumin / Globulin Ratio,,,,,False
Bilirubin,0.3,mg/dL,97-108,,False
"Chloride, Serum",103.0,mmol/L,3.5-5.2,,False
Follicle Stim Hormone,216.0,mg/dL,,,False
COMMAND
- Write a pathology report which uses all the inputs above.
- Correct any spelling errors.
OUTPUT REQUIREMENTS
- Title the report as "Pathology Report" with the following four sections - Patient Information, Blood Test Results, Commentary, Recommendations.
- Under the “Patient Information” section, include name, date of birth, and smoking status.
- Under the “Blood Test Results” section, list, in numbered points, each observation with result, reference range if available, and a brief note (e.g., “Within range” / “Abnormal”).
- Under the “Commentary” section, summarise what the results suggest for risk stratification and potential next steps.
- Under the “Recommendations” section, provide practical, action-oriented follow-ups (labs to repeat, monitoring, lifestyle / care-management prompts).
CONSTRAINTS
- If no observations are available, state that clearly and focus on next steps.
- Use professional, clear language suitable for care-management and providers.
*End of prompt*
Pathology Report
Patient Information
Name: Justin Ford
Date of Birth: June 11, 1991
Smoking Status: 0 cigarettes per day (previous smoker)
Blood Test Results
1. Albumin / Globulin Ratio: Within range
2. Bilirubin: 0.3 mg/dL (Reference Range: 97-108) - Within range
3. Chloride, Serum: 103.0 mmol/L (Reference Range: 3.5-5.2) - Within range
4. Follicle Stim Hormone: 216.0 mg/dL - Abnormal (Reference Range: not available)
Commentary
The results suggest that Justin Ford has a low risk for liver disease, as indicated by his normal bilirubin levels. However, his elevated follicle stim hormone level (216.0 mg/dL) is outside the reference range and may indicate an ovarian tumor or other abnormalities in the reproductive system. Further testing is necessary to confirm these findings.
Recommendations
1. Repeat follicle stim hormone testing to confirm abnormal results.
2. Monitor bilirubin levels to ensure that they remain within normal range.
3. Provide lifestyle counseling to support a healthy weight and promote regular exercise, as these factors may reduce the risk of ovarian tumors and other reproductive cancers.
Q3b - Critique LLM advice
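One simple way to support the critique is to test whether a note such as "Within range" agrees with the reference range printed next to it. The sketch below is a minimal illustration, not part of the original submission: the helper names `parse_range` and `range_note` are hypothetical, and the three example rows are copied from the Sample 2 input table. It checks internal consistency only; it says nothing about whether the reference ranges in the data are clinically correct for each test.

```python
import math
import re


def parse_range(ref_range):
    """Return (low, high) for strings like '3.5-5.2', or None if the range is not numeric."""
    m = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)\s*", str(ref_range))
    return (float(m.group(1)), float(m.group(2))) if m else None


def range_note(value, ref_range):
    """Compare a numeric value with a 'low-high' reference range (hypothetical helper)."""
    bounds = parse_range(ref_range)
    missing = value is None or (isinstance(value, float) and math.isnan(value))
    if missing or bounds is None:
        return "Not assessable"
    low, high = bounds
    if value < low:
        return "Below range"
    if value > high:
        return "Above range"
    return "Within range"


# Example rows taken from the Sample 2 input table: (test, value, reference range)
rows = [
    ("Globulin", 2.5, "1.1-2.5"),
    ("Bilirubin", 0.3, "97-108"),
    ("Potassium, Serum", 3.9, "Negative"),
]

for name, value, ref in rows:
    print(f"{name}: {value} vs {ref} -> {range_note(value, ref)}")
```

Rows where this note disagrees with the note in the LLM output or the written report are candidates for closer manual review, since the disagreement may come from the data itself rather than from the report.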
Q3b Step 1 - Analyse LLM Advice with Comparison to Actual Pathology Reports
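The cell in this step strips the echoed prompt from each LLM response so that only the generated report is compared with the written report. As a minimal sketch of that idea, the snippet below applies the same regular expression to a toy string. The string `raw` and the helper name `strip_prompt` are hypothetical and exist only for illustration.

```python
import re

# Hypothetical toy value mimicking a raw LLM response:
# the echoed prompt, the '*End of prompt*' marker, then the generated report.
raw = (
    "CONTEXT\nYou are a pathologist...\n"
    "*End of prompt*\n"
    "Pathology Report\nName: Example Patient"
)

# (?is) turns on DOTALL and IGNORECASE; '.*?' is lazy, so the match stops at the FIRST marker.
marker = re.compile(r"(?is).*?\*end of prompt\*\s*")


def strip_prompt(text):
    """Return the text after the first '*End of prompt*' marker, or the trimmed text if absent."""
    m = marker.search(text)
    return text[m.end():].strip() if m else text.strip()


cleaned = strip_prompt(raw)
print(cleaned)
```

Because the match is lazy, a response that happens to contain the marker twice keeps everything after the first occurrence, including the second marker.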
# Q3b Step 1
def strip_up_to_end_of_prompt(text: str) -> str:
    """
    Remove everything up to and including the first '*End of prompt*' (case-insensitive).
    If the marker is not found, return the original text (trimmed).
    """
    if not isinstance(text, str):
        return text
    # Lazily match everything up to the first '*End of prompt*', plus any trailing whitespace/newlines
    # (?is) -> re.DOTALL + re.IGNORECASE
    pattern = re.compile(r'(?is).*?\*end of prompt\*\s*')
    m = pattern.search(text)
    if m:
        return text[m.end():].strip()
    return text.strip()


def add_cleaned_results_column(prompts_df: pd.DataFrame,
                               src_col: str = "ResultsFromLLM",
                               dst_col: str = "ResultsFromLLM_cleaned") -> pd.DataFrame:
    """
    Return a copy of prompts_df with a new column holding the content after '*End of prompt*'.
    """
    out = prompts_df.copy()
    out[dst_col] = out[src_col].apply(strip_up_to_end_of_prompt)
    return out


prompts_df = add_cleaned_results_column(prompts_df)

# For each row, print the written report followed by the cleaned LLM output.
# Sample numbering (i + 1) assumes prompts_df has the default 0-based integer index.
for i, row in prompts_df.iterrows():
    print(f"Written Report for LabResultGuid - {row['LabResultGuid']} (Sample {i + 1}):")
    print()
    print(row['written_report'])
    print()
    print("-" * 550)
    print("-" * 550)
    print()
    print(f"Results from LLM for LabResultGuid - {row['LabResultGuid']} (Sample {i + 1}):")
    print()
    print(row['ResultsFromLLM_cleaned'])
    # Print a separator after every sample except the last one
    if i < len(prompts_df) - 1:
        print()
        print("-" * 550)
        print("-" * 550)
        print()
Written Report for LabResultGuid - fc641eaf-1dce-48bb-a004-5b2f7b93715b (Sample 1):
**Pathology Report:**
Patient: Julie Kemper
Date of Birth: September 1, 1982
Smoking Status: 0 cigarettes per day (previous smoker)
**Abnormal Results:**
1. **Urobilinogen:**
- Result: 5.2 g/dL
- Reference Range: NA
- Abnormal Flags: FALSE
2. **Triglyceride:**
- Result: 84.0 mg/dL
- Reference Range: 1.8-7.8 mg/dL
- Abnormal Flags: FALSE
3. **Cannabinoids:**
- Result: 0.7 g/dL
- Reference Range: NA
- Abnormal Flags: FALSE
4. **Urobilinogen:**
- Result: 32.0 mg/dL
- Reference Range: Yellow
- Abnormal Flags: Above Normal High
**Comments:**
The blood test results for Julie Kemper show elevated levels of urobilinogen (32.0 mg/dL), which is above the normal high range for this parameter. This abnormality may indicate liver dysfunction or hemolysis, and further investigation or monitoring may be needed.
The triglyceride level is within the normal range despite being on the higher end (84.0 mg/dL). Given the patient's previous smoking status, it is essential to monitor lipid levels regularly to assess cardiovascular risk factors.
The cannabinoid level is slightly elevated (0.7 g/dL), which may indicate recent cannabis use. Given the patient's smoking history, it is important to discuss smoking cessation and its impact on overall health.
**Recommendation:**
1. Follow up with additional liver function tests to determine the cause of elevated urobilinogen levels.
2. Monitor triglyceride levels regularly and consider lifestyle modifications to reduce cardiovascular risk factors.
3. Discuss the impact of smoking cessation on overall health and encourage lifestyle changes to improve well-being.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - fc641eaf-1dce-48bb-a004-5b2f7b93715b (Sample 1):
Pathology Report
Patient Information
Name: Cecile Wise
Date of Birth: September 01, 1982
Smoking Status: 0 cigarettes per day (non-smoker or less than 100 in lifetime)
Blood Test Results
1. Cannabinoids: 0.7 g/dL
- Reference Range: N/A
- Note: Within range
2. Thyroxine:
- Reference Range: N/A
- Note: Within range
3. Urobilinogen:
- 5.2 g/dL
- Reference Range: 0.0-1.0 g/dL
- Note: Within range
4. Triglyceride: 84.0 mg/dL
- Reference Range: 1.8-7.8 mg/dL
- Note: Elevated
5. Urobilinogen:
- 32.0 mg/dL
- Reference Range: N/A
- Note: Abnormal (above normal high)
Commentary
The blood test results for Cecile Wise suggest potential liver dysfunction, as indicated by the elevated levels of urobilinogen. The urobilinogen level of 32.0 mg/dL is above normal high and may indicate a potential biliary obstruction. This risk factor should be further evaluated through additional liver function tests and imaging studies.
Recommendations
We recommend that Cecile Wise undergo liver function tests, including alkaline phosphatase and bilirubin, to further evaluate the potential liver dysfunction. She should also be advised to avoid excessive alcohol consumption and maintain a healthy diet and lifestyle. A follow-up appointment should be scheduled to discuss the results and determine the appropriate care management plan.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Written Report for LabResultGuid - dec4ba11-fecc-4ba3-9ef2-2ec4f6f2d9d3 (Sample 2):
## Pathology Report
### Patient Information:
- **Name:** Fern Dugger
- **Date of Birth:** March 28, 1994
- **Smoking Status:** 0 cigarettes per day (non-smoker or less than 100 in lifetime)
### Abnormal Results:
1. **Monocytes:**
- **Result:** 0.6 x10E3/uL
- **Reference Range:** 140-415
- **Abnormality:** Above Normal High
- **Implications:** Elevated monocyte levels may indicate an underlying inflammatory condition or infection. Further evaluation may be needed.
2. **Triglyceride:**
- **Result:** 164.0 mg/dL
- **Reference Range:** 1.8-7.8
- **Abnormality:** Above Normal High
- **Implications:** High triglyceride levels may increase the risk of cardiovascular disease. Lifestyle modifications and dietary changes may be recommended.
### Comments:
Based on the abnormal results in the blood tests, further evaluation and follow-up may be necessary to assess any underlying conditions. Given the patient's non-smoker status, the abnormalities observed are unlikely to be directly related to smoking habits. Lifestyle modifications and consultations with other healthcare providers may be recommended to address the elevated levels observed in this report. Further diagnostic tests and monitoring may be necessary for a comprehensive assessment of the patient's health status.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - dec4ba11-fecc-4ba3-9ef2-2ec4f6f2d9d3 (Sample 2):
Pathology Report
Patient Information:
Name: Fern Dugger
Date of Birth: March 28, 1994
Smoking Status: 0 cigarettes per day (non-smoker or less than 100 in lifetime)
Blood Test Results:
1. Albumin / Globulin Ratio: 1.8
Reference Range: 6.0-8.5
Result: Within range
2. Bilirubin: 0.3 mg/dL
Reference Range: 97-108 mg/dL
Result: Within range
3. Chloride, Serum: 104.0 mmol/L
Reference Range: 3.5-5.2 mmol/L
Result: Within range
4. Globulin: 2.5 g/dL
Reference Range: 1.1-2.5 g/dL
Result: Within range
5. Hematocrit: 37.3%
Reference Range: 11.5-15.0%
Result: Within range
6. Hemoglobin: 12.1 g/dL
Reference Range: 0-149 g/dL
Result: Within range
7. Monocytes: 0.6 x10E3/uL
Reference Range: 140-415 x10E3/uL
Result: Above Normal High
8. Neutrophils: 1.7 x10E3/uL
Reference Range: 0.1-1.0 x10E3/uL
Result: Within range
9. Platelet Count: 145.0 x10E3/uL
Reference Range: 1.5-4.5 x10E3/uL
Result: Within range
10. Potassium, Serum: 3.9 mmol/L
Reference Range: Negative
Result: Within range
11. Protein Total: 6.7 g/dL
Reference Range: 0.0-1.2 g/dL
Result: Within range
12. Triglyceride: 164.0 mg/dL
Reference Range: 1.8-7.8 mg/dL
Result: Above Normal High
Commentary:
Fern Dugger's blood test results indicate that her liver and kidney function are normal, with no significant abnormalities. However, her monocyte count is above normal high, which may be indicative of an inflammatory process, infection, or malignancy. Further investigation and monitoring are recommended to identify the underlying cause and to assess the need for additional testing or treatment.
Recommendations:
1. Repeat a complete blood count (CBC) in 2-3 weeks to monitor monocyte count and to assess any changes in other blood cell lines.
2. Evaluate the patient's symptoms, medical history, and physical examination findings to determine the cause of the elevated monocyte count.
3. Consider additional testing, such as a C-reactive protein (CRP) or erythrocyte sedimentation rate (ESR), to assess for inflammation or infection.
4. Monitor the patient's symptoms and follow-up with her primary care physician or specialist as needed.
5. Encourage lifestyle modifications, such as quitting smoking, maintaining a healthy weight, and engaging in regular physical activity, to promote overall health and wellness.
*End of report*
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Written Report for LabResultGuid - cf9823b9-3666-4fb9-a494-608d11ae74cd (Sample 3):
**Pathology Report:**
**Patient Information:**
- Name: Brandon Chase
- Date of Birth: March 28, 2000
- Smoking Status: 0 cigarettes per day (previous smoker)
**Abnormal Results:**
1. **Bilirubin:** The bilirubin level is 0.4 mg/dL, which is above the normal range of 97-108 mg/dL. This indicates a potential issue with liver function.
2. **Hemoglobin:** The hemoglobin level is 12.6 g/dL, below the normal range of 12.5-17.0 g/dL. This could indicate anemia or another underlying health condition.
3. **Triglycerides:** The triglyceride level is 170.0 mg/dL, above the normal range of 1.8-7.8 mg/dL. Elevated triglycerides are a risk factor for cardiovascular disease.
**Comments:**
Based on the abnormal results in Brandon Chase's blood test, further evaluation and follow-up are recommended. It is important to investigate the potential liver function issues indicated by the elevated bilirubin level. Additionally, the low hemoglobin level suggests the possibility of anemia that requires attention. The high triglyceride level raises concerns about cardiovascular health and may require lifestyle modifications or medical intervention.
Given Brandon's history of smoking, it is important to consider the impact of smoking on these abnormal results. Smoking can exacerbate liver function issues, contribute to anemia, and worsen cardiovascular risk factors like elevated triglycerides. Therefore, smoking cessation should be strongly encouraged to improve overall health outcomes.
Further testing, evaluation, and lifestyle modifications should be discussed with Brandon Chase to address these abnormal blood test results effectively.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - cf9823b9-3666-4fb9-a494-608d11ae74cd (Sample 3):
Pathology Report
Patient Information
Name: Brandon Chase
Date of Birth: March 28, 2000
Smoking Status: 0 cigarettes per day (previous smoker)
Blood Test Results
1. Albumin / Globulin Ratio: 1.6
Reference Range: 6.0-8.5
Note: Within range
2. Ambig Abbrev CMP14 Default:
Reference Range:
Note: Within range
3. Ambig Abbrev LP Default:
Reference Range:
Note: Within range
4. Bilirubin: 0.4mg/dL
Reference Range: 97-108
Note: Above Normal High (Abnormal)
5. Chloride, Serum: 96.0mmol/L
Reference Range: 3.5-5.2
Note: Within range
6. Globulin: 2.5g/dL
Reference Range: 1.1-2.5
Note: Within range
7. Hematocrit: 37.9%
Reference Range: 36.0-50.0
Note: Within range
8. Hemoglobin: 12.6g/dL
Reference Range: 12.5-17.0
Note: Below Normal Low (Abnormal)
9. Monocytes: 0.6x10E3/uL
Reference Range: 140-415
Note: Within range
10. Neutrophils: 6.9x10E3/uL
Reference Range: 0.1-1.0
Note: Within range
11. Platelet Count: 253.0x10E3/uL
Reference Range: 1.5-4.5
Note: Within range
12. Potassium, Serum: 3.7mmol/L
Reference Range: Negative
Note: Within range
13. Protein Total: 6.9g/dL
Reference Range: 0.0-1.2
Note: Within range
14. Triglyceride: 170.0mg/dL
Reference Range: 1.8-7.8
Note: Above Normal High (Abnormal)
Commentary
The blood test results suggest that Brandon Chase has elevated bilirubin levels, which may indicate liver or biliary tract problems. Additionally, his hemoglobin levels are lower than expected, which may be indicative of anemia. Further testing may be necessary to determine the cause of these abnormalities.
Recommendations
- Repeat liver function tests to monitor for any changes and investigate the cause of the elevated bilirubin levels.
- Consider performing a complete blood count (CBC) with differential to investigate the cause of the low hemoglobin levels.
- Encourage Brandon Chase to adopt a healthier lifestyle, including a balanced diet, regular exercise, and avoiding smoking and excessive alcohol consumption.
- Monitor his blood pressure and cholesterol levels to ensure they remain within a healthy range.
- Recommend that Brandon Chase schedule a follow-up appointment with his primary care physician to discuss these results and any potential next steps.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Written Report for LabResultGuid - 75177378-be69-4fcb-9c3b-3981815bd9cb (Sample 4):
**Pathology Report:**
- **AST (SGOT) (Aspartate Aminotransferase):** The result is not available. This test is often used to evaluate liver function and can be elevated in conditions like liver disease or damage. Without this value, it is difficult to assess the patient's liver health.
- **CMP14 Default (Comprehensive Metabolic Panel):** The result is not available, so we cannot comment on the complete metabolic profile of the patient.
- **Bilirubin:** The patient's total bilirubin level is 1.2 mg/dL, which falls within the reference range of 0.2-1.2 mg/dL. However, there is an additional abnormal result for bilirubin indicating "below normal low," but the actual value is not available. Low bilirubin levels can be seen in conditions like anemia, malnutrition, or liver disease.
- **Chloride, Serum:** The patient's serum chloride level is 10.0 mmol/L, which is above the reference range of 3.5-5.2 mmol/L. Elevated chloride levels can sometimes be seen in conditions like dehydration or kidney problems.
**Comments:**
Given the incomplete and abnormal test results, further evaluation is recommended to assess the patient's liver function and overall metabolic status. It is important to follow up with additional testing to determine the reason for the abnormal bilirubin levels and elevated chloride levels. As the patient is a non-smoker, smoking is not likely a factor in these abnormal results. Please consult with the ordering physician for additional investigations and management as necessary.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - 75177378-be69-4fcb-9c3b-3981815bd9cb (Sample 4):
Pathology Report
Patient Information:
Name: Barbara Cubbage
Date of Birth: December 05, 2003
Smoking Status: 0 cigaretttes per day (non-smoker or less than 100 in lifetime)
Blood Test Results:
1. AST (SGOT):
Observation Value: N/A
Units: N/A
Reference Range: N/A
AbnormalFlags: False
IsAbnormalValue: False
Notes: Within range
2. Ambig Abbrev CMP14 Default:
Observation Value: N/A
Units: N/A
Reference Range: N/A
AbnormalFlags: False
IsAbnormalValue: False
Notes: Within range
3. Bilirubin:
Observation Value: 1.2
Units: mg/dL
Reference Range: 97-108
AbnormalFlags: False
IsAbnormalValue: False
Notes: Below Normal Low (Note: Bilirubin is lower than the lower limit of normal)
Commentary:
The patient's blood test results suggest that her liver function (AST) is within the normal range. However, her bilirubin level is lower than the lower limit of normal, which may indicate potential liver dysfunction. Further investigation or follow-up testing may be necessary.
Recommendations:
1. Repeat liver function tests (AST, ALT, ALP) in 4-6 weeks to monitor potential liver dysfunction.
2. Monitor bilirubin levels and investigate potential underlying causes, such as hepatitis, cholestasis, or drug-induced liver injury.
3. Encourage patient to avoid alcohol and maintain a healthy weight to reduce potential liver damage.
*End of report*
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Written Report for LabResultGuid - f715878a-45d5-4f62-aa73-c21399cd89f3 (Sample 5):
**Pathology Report:**
Patient Name: Justin Ford
Date of Birth: June 11, 1991
Smoking Status: 0 cigarettes per day (previous smoker)
**Abnormal Results:**
1. **Follicle Stimulating Hormone:**
- Observation Value: 216.0 mg/dL
- Reference Range: NA
- Abnormal: FALSE
**Comments:** The Follicle Stimulating Hormone level is significantly elevated. This could indicate issues with reproductive health or function. Further evaluation may be necessary to determine the underlying cause.
**Comments:**
Given the elevated Follicle Stimulating Hormone level in the blood test results, it is recommended that Justin undergo further evaluation by a reproductive endocrinologist to assess his fertility health and potential reproductive issues. It is important for him to discuss these results and the implications with a healthcare provider. Additionally, since Justin is a previous smoker, it is important to consider how past smoking habits may have contributed to any potential reproductive health issues. Smoking can have detrimental effects on fertility and reproductive function, so discussing the cessation of smoking with a healthcare provider is also advisable.
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Results from LLM for LabResultGuid - f715878a-45d5-4f62-aa73-c21399cd89f3 (Sample 5):
Pathology Report
Patient Information
Name: Justin Ford
Date of Birth: June 11, 1991
Smoking Status: 0 cigarettes per day (previous smoker)
Blood Test Results
1. Albumin / Globulin Ratio: Within range
2. Bilirubin: 0.3 mg/dL (Reference Range: 97-108) - Within range
3. Chloride, Serum: 103.0 mmol/L (Reference Range: 3.5-5.2) - Within range
4. Follicle Stim Hormone: 216.0 mg/dL - Abnormal (Reference Range: not available)
Commentary
The results suggest that Justin Ford has a low risk for liver disease, as indicated by his normal bilirubin levels. However, his elevated follicle stim hormone level (216.0 mg/dL) is outside the reference range and may indicate an ovarian tumor or other abnormalities in the reproductive system. Further testing is necessary to confirm these findings.
Recommendations
1. Repeat follicle stim hormone testing to confirm abnormal results.
2. Monitor bilirubin levels to ensure that they remain within normal range.
3. Provide lifestyle counseling to support a healthy weight and promote regular exercise, as these factors may reduce the risk of ovarian tumors and other reproductive cancers.
Q3b Step 2 - Final Words for Sharing with Betahelf¶
Actionable Insights
Structure and completeness. The LLM reliably delivered the required headings and covered the inputs provided. However, where the raw lab observations lacked clear numeric reference ranges or used atypical formats, the LLM sometimes produced generic or inconsistent result interpretations (e.g. in Sample 3, Potassium is marked "Within range" against a reference range of "Negative").
- Action: Add a guardrail: "If a reference is missing or non-numeric, state 'Reference range not supplied' and do not judge normality."
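The same guardrail could also be enforced in code before the observations reach the LLM, so that normality is never inferred from an unusable range. The following is a minimal sketch under stated assumptions: the function name `judge_normality`, the regular expression, and the exact wording of the returned notes are illustrative and are not part of the original notebook.

```python
import re

# Assumption: a usable reference range looks like "36.0-50.0" (two numbers joined by a hyphen).
_NUMERIC_RANGE = re.compile(r"^\s*(-?\d+(?:\.\d+)?)\s*-\s*(-?\d+(?:\.\d+)?)\s*$")


def judge_normality(value, reference_range):
    """Return a normality note, or decline to judge when the inputs are unusable."""
    if value is None:
        return "Result not available"
    match = _NUMERIC_RANGE.match(reference_range or "")
    if match is None:
        # Missing or non-numeric range (e.g. "Negative", "NA"): do not judge normality.
        return "Reference range not supplied"
    low, high = float(match.group(1)), float(match.group(2))
    if value < low:
        return "Below range"
    if value > high:
        return "Above range"
    return "Within range"


# Values taken from the Sample 3 and Sample 5 outputs above.
print(judge_normality(37.9, "36.0-50.0"))  # Within range
print(judge_normality(3.7, "Negative"))    # Reference range not supplied
print(judge_normality(216.0, None))        # Reference range not supplied
```

Doing the check deterministically means the prompt instruction becomes a second line of defence rather than the only one.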
Specifics vs. generality. Humans often included precise rationale tied to the particular assays (e.g. context from prior trends or obvious clinical caveats). The LLM was strong at clear summaries but tended to generalise (e.g. broad risk statements not anchored to thresholds).
- Action: Provide a compact "evidence card" per observation (result, reference range, change from prior if available) and require the LLM to cite these lines explicitly in the "Commentary" section.
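One possible shape for such an evidence card is sketched below. The `Observation` fields, the `[E1]`-style tags, and the helper `uncited_cards` are assumptions made for illustration; the notebook's actual column names and prompt format may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Observation:
    name: str
    value: Optional[float]
    units: str
    reference_range: Optional[str]
    prior_value: Optional[float] = None  # previous result, if one is on file


def evidence_card(index, obs):
    """Render one observation as a single tagged line the LLM can cite."""
    value = "not available" if obs.value is None else f"{obs.value:g} {obs.units}".strip()
    reference = obs.reference_range or "not supplied"
    if obs.value is None or obs.prior_value is None:
        change = "no prior result"
    else:
        change = f"{obs.value - obs.prior_value:+g} vs prior"
    return f"[E{index}] {obs.name} | result: {value} | reference: {reference} | change: {change}"


def uncited_cards(commentary, card_count):
    """Return the card numbers whose tag never appears in the commentary."""
    return [i for i in range(1, card_count + 1) if f"[E{i}]" not in commentary]


# Hemoglobin values as printed in the Sample 3 output above.
print(evidence_card(1, Observation("Hemoglobin", 12.6, "g/dL", "12.5-17.0")))
```

A check such as `uncited_cards` could then be run on each generated "Commentary" section to flag observations the LLM summarised without citing.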
Actionability for care-management. Humans often translated abnormal findings into concrete next steps (e.g. which test to repeat). LLM recommendations were readable and safe but sometimes non-committal ("consider additional testing").
- Action: Insert a policy snippet with Betahelf-approved escalation and follow-up intervals (e.g. "If X > Y, repeat in Z weeks").
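A rough sketch of how such a policy snippet might be stored once and rendered into the prompt is given below. Everything in it is hypothetical: the `FollowUpRule` class, the analyte name, the threshold, and the interval are placeholders, not clinical guidance and not Betahelf policy.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FollowUpRule:
    test_name: str
    threshold: float
    units: str
    repeat_weeks: int


# Placeholder rule only: the analyte, threshold and interval are invented for illustration.
RULES = [FollowUpRule("Example Analyte", 10.0, "mg/dL", 6)]


def render_policy(rules):
    """Turn the rule table into text that can be pasted into the LLM prompt."""
    lines = ["Follow-up policy (use only these rules when recommending repeat tests):"]
    for rule in rules:
        lines.append(
            f"- If {rule.test_name} > {rule.threshold:g} {rule.units}, "
            f"repeat in {rule.repeat_weeks} weeks."
        )
    return "\n".join(lines)


def repeat_interval(rules, test_name, value):
    """Return the repeat interval in weeks for a result, or None if no rule fires."""
    for rule in rules:
        if rule.test_name == test_name and value > rule.threshold:
            return rule.repeat_weeks
    return None


print(render_policy(RULES))
```

Keeping the rules in one table means the prompt text and any downstream check of the LLM's recommendations are derived from the same source.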
Prompt adherence. The LLM sometimes ignores instructions in the prompt, such as "correct all spelling errors"; for example, the Sample 4 output above still contains "cigaretttes".
- Action: Review LLM-generated reports before use. A spelling mistake can change the meaning of a word, phrase, or sentence and lead to an incorrect diagnosis, which would be highly detrimental to Betahelf and its patients.
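Part of this review could be automated as a first pass before a human reads the report. The sketch below uses Python's standard `difflib` module to flag words that are close to, but not exactly, a known term. The vocabulary `KNOWN_TERMS` and the 0.85 similarity cutoff are assumptions chosen for the example; a real check would need a proper medical word list.

```python
import difflib
import re

# Assumed vocabulary for illustration; in practice this would be a medical dictionary.
KNOWN_TERMS = ["cigarettes", "bilirubin", "hemoglobin", "chloride", "smoker", "per", "day"]


def flag_misspellings(text, vocabulary=KNOWN_TERMS, cutoff=0.85):
    """Map each suspicious word to the known term it most closely resembles."""
    known = set(vocabulary)
    flagged = {}
    for word in re.findall(r"[A-Za-z]+", text.lower()):
        if word in known:
            continue
        close = difflib.get_close_matches(word, vocabulary, n=1, cutoff=cutoff)
        if close:
            flagged[word] = close[0]
    return flagged


# Line copied from the Sample 4 LLM output above.
line = "Smoking Status: 0 cigaretttes per day (non-smoker or less than 100 in lifetime)"
print(flag_misspellings(line))  # {'cigaretttes': 'cigarettes'}
```

This only catches near-misses of listed terms, so it complements rather than replaces the pathologist's read-through.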
Strengths and Weaknesses of the LLM Output by Stakeholder
Patients:
Strengths:
- A highly consistent format, readable language, and a clear list of next steps support outreach and documentation.
Weaknesses:
- Generic risk language can over-alarm or under-specify, resulting in unnecessary follow-ups.
Pathologists & Clinicians:
Strengths:
- Time savings for routine panels; a starting draft reduces cognitive load; terminology is standardised.
Weaknesses:
- Occasional over-interpretation of ambiguous ranges.
Compliance:
Strengths:
- Templated outputs facilitate auditability and version control.
Weaknesses:
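- Risk of hallucination (e.g. the Sample 5 output raises a possible ovarian tumor while referring to the patient as "his"). Separately, if missing data or access differs between patient sub-groups, commentary could skew perceived risk.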
Betahelf's Management Team:
Strengths:
- Scalable coverage, reduced turnaround time, standardisation, implying potential cost and quality gains.
Weaknesses:
- Compute cost, governance overhead, and change-management within clinical teams.
Recommendations for Betahelf
Keep the LLM as a pathology report co-pilot and decision support tool, not a standalone author, and require a pathologist sign-off.
- The LLM can also act as a cross-checker. For example, in the first sample the human-written report gave a different patient name from the LLM output. Further checks found that the name produced by the LLM was the correct one, assuming that the raw data is correct.
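A name cross-check of this kind could be scripted along the following lines. This is a sketch only: the helper names, the regular expression, and the dictionary layout are assumptions, and it presumes each report carries a "Name:" or "Patient Name:" line as in the samples above.

```python
import re

# Matches lines such as "Name: Justin Ford" or "**Patient Name:** Justin Ford".
_NAME_LINE = re.compile(
    r"^[^\w\n]*(?:Patient\s+)?Name:[^\w\n]*(.+?)\s*$",
    re.IGNORECASE | re.MULTILINE,
)


def extract_patient_name(report_text):
    """Return the name on the first 'Name:' line, or None if there is none."""
    match = _NAME_LINE.search(report_text)
    return match.group(1) if match else None


def name_mismatches(raw_name, reports):
    """Return {source: name found} for every report that disagrees with the raw data."""
    expected = raw_name.strip().casefold()
    issues = {}
    for source, text in reports.items():
        found = extract_patient_name(text)
        if found is None or found.casefold() != expected:
            issues[source] = found
    return issues


# Header lines copied from the two Sample 5 reports above.
reports = {
    "human": "Patient Name: Justin Ford\nDate of Birth: June 11, 1991",
    "llm": "Patient Information\nName: Justin Ford\nDate of Birth: June 11, 1991",
}
print(name_mismatches("Justin Ford", reports))  # {}
```

An empty result means both reports agree with the raw record; any entry returned would be routed to a pathologist for review.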
Tighten the LLM's inputs and enforce output rules.
Track KPIs, such as report turnaround times and any evidence of discrimination between patient sub-groups.
- Together, these recommendations preserve the LLM's strengths (speed, consistency, readability) while mitigating clinical and compliance risk, aligning with Betahelf's goals of safer, earlier, and more scalable patient intervention.
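The KPI tracking recommended above could start from something as simple as the sketch below. The record fields (`received`, `signed_off`, `group`, `flagged_abnormal`) are hypothetical, and using the gap in abnormal-flag rates between sub-groups as a discrimination signal is an assumption made for illustration, not a metric defined in this notebook.

```python
from collections import defaultdict
from datetime import datetime
from statistics import median


def median_turnaround_hours(records):
    """Median time from lab result received to pathologist sign-off, in hours."""
    hours = [(r["signed_off"] - r["received"]).total_seconds() / 3600 for r in records]
    return median(hours)


def abnormal_rate_by_group(records):
    """Share of reports flagged abnormal within each patient sub-group."""
    counts = defaultdict(lambda: [0, 0])  # group -> [flagged, total]
    for r in records:
        counts[r["group"]][0] += int(r["flagged_abnormal"])
        counts[r["group"]][1] += 1
    return {group: flagged / total for group, (flagged, total) in counts.items()}


def max_rate_gap(rates):
    """Largest difference in flag rate between any two sub-groups."""
    return max(rates.values()) - min(rates.values())


# Made-up records purely to show the expected input shape.
records = [
    {"received": datetime(2025, 1, 1, 9), "signed_off": datetime(2025, 1, 1, 13),
     "group": "A", "flagged_abnormal": True},
    {"received": datetime(2025, 1, 1, 9), "signed_off": datetime(2025, 1, 1, 15),
     "group": "A", "flagged_abnormal": False},
    {"received": datetime(2025, 1, 2, 9), "signed_off": datetime(2025, 1, 2, 17),
     "group": "B", "flagged_abnormal": True},
]
print(median_turnaround_hours(records))
print(max_rate_gap(abnormal_rate_by_group(records)))
```

A persistent gap between sub-groups would not prove discrimination on its own, but it would indicate where a closer clinical and compliance review is warranted.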